Browse Source

[9.1] [ML] Restore delayed string copying (#135242) (#135464)

This commit restores the behaviour introduced in #125837 which was
inadvertently undone by changes in #121041, specifically, delaying
copying Strings as part of calling Request.chunkText() until the request
is being executed.

In addition to the above change, refactor doChunkedInfer() and its
implementations to take a List<ChunkInferenceInput> rather than
EmbeddingsInput, since the EmbeddingsInput passed into doChunkedInfer()
was immediately discarded after extracting the ChunkInferenceInput list
from it. This change allowed the EmbeddingsInput class to be refactored
to not know about ChunkInferenceInput, simplifying it significantly.

This commit also simplifies EmbeddingRequestChunker.Request to take only
the input String rather than the entire list of all inputs, since only
one input is actually needed. This change prevents Requests from
retaining a reference to the input list, potentially allowing it to be
GC'd faster.

(cherry picked from commit f3447d3d45d5caa96f2e3b3f6a124019b96178e6)
Donal Evans 3 weeks ago
parent
commit
1db5f0d512
69 changed files with 291 additions and 297 deletions
  1. 7 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
  2. 22 38
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java
  3. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java
  4. 3 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
  5. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java
  6. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
  7. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java
  8. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java
  9. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
  10. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
  11. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java
  12. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
  13. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java
  14. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
  15. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
  16. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java
  17. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java
  18. 5 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
  19. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java
  20. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java
  21. 7 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
  22. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java
  23. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java
  24. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java
  25. 4 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
  26. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java
  27. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
  28. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java
  29. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
  30. 17 8
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
  31. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java
  32. 8 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
  33. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java
  34. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
  35. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java
  36. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java
  37. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
  38. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java
  39. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java
  40. 0 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java
  41. 41 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java
  42. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
  43. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java
  44. 5 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java
  45. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java
  46. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java
  47. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java
  48. 4 5
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java
  49. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java
  50. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
  51. 6 10
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java
  52. 5 26
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java
  53. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java
  54. 7 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java
  55. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java
  56. 3 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
  57. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java
  58. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
  59. 8 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
  60. 4 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java
  61. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java
  62. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java
  63. 6 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java
  64. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java
  65. 4 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
  66. 7 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
  67. 6 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java
  68. 1 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java
  69. 7 20
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java

+ 7 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

@@ -37,16 +37,16 @@ import java.util.stream.Collectors;
  * a single large input that has been chunked may spread over
  * multiple batches.
  *
- * The final aspect it to gather the responses from the batch
+ * The final aspect is to gather the responses from the batch
  * processing and map the results back to the original element
  * in the input list.
  */
 public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     // Visible for testing
-    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
+    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
         public String chunkText() {
-            return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
+            return input.substring(chunk.start(), chunk.end());
         }
     }
 
@@ -60,7 +60,7 @@ public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
 
-    // The maximum number of chunks that is stored for any input text.
+    // The maximum number of chunks that are stored for any input text.
     // If the configured chunker chunks the text into more chunks, each
     // chunk is sent to the inference service separately, but the results
     // are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
                 chunkingSettings = defaultChunkingSettings;
             }
             Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
-            List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
+            String inputString = inputs.get(inputIndex).input();
+            List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
             int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
             resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
             resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
                 } else {
                     resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
                 }
-                allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
+                allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
             }
         }
 

+ 22 - 38
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

@@ -8,63 +8,47 @@
 package org.elasticsearch.xpack.inference.external.http.sender;
 
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ChunkInferenceInput;
-import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InputType;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
-import java.util.stream.Collectors;
 
 public class EmbeddingsInput extends InferenceInputs {
-
-    public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
-        if (inferenceInputs instanceof EmbeddingsInput == false) {
-            throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
-        }
-
-        return (EmbeddingsInput) inferenceInputs;
-    }
-
-    private final Supplier<List<ChunkInferenceInput>> listSupplier;
+    private final Supplier<List<String>> inputListSupplier;
     private final InputType inputType;
+    private final AtomicBoolean supplierInvoked = new AtomicBoolean();
 
-    public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
-        super(false);
-        this.listSupplier = Objects.requireNonNull(inputSupplier);
-        this.inputType = inputType;
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
+        this(() -> input, inputType, false);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
-        this(input, chunkingSettings, inputType, false);
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
+        this(() -> input, inputType, stream);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
-        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
+    public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
+        this(inputSupplier, inputType, false);
     }
 
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
-        this(input, inputType, false);
-    }
-
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
+    private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
         super(stream);
-        Objects.requireNonNull(input);
-        this.listSupplier = () -> input;
+        this.inputListSupplier = Objects.requireNonNull(inputSupplier);
         this.inputType = inputType;
     }
 
-    public List<ChunkInferenceInput> getInputs() {
-        return this.listSupplier.get();
-    }
-
-    public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
-        return new EmbeddingsInput(input, null, inputType);
-    }
-
-    public List<String> getStringInputs() {
-        return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
+    /**
+     * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
+     * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
+     * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling
+     * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
+     *
+     * @return a list of String embedding inputs
+     */
+    public List<String> getInputs() {
+        assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
+        return inputListSupplier.get();
     }
 
     public InputType getInputType() {

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

@@ -52,7 +52,7 @@ public class TruncatingRequestManager extends BaseRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
+        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
         var truncatedInput = truncate(docsInput, maxInputTokens);
         var request = requestCreator.apply(truncatedInput);
 

+ 3 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -101,7 +101,7 @@ public abstract class SenderService implements InferenceService {
                 if (validationException.validationErrors().isEmpty() == false) {
                     throw validationException;
                 }
-                yield new EmbeddingsInput(input, null, inputType, stream);
+                yield new EmbeddingsInput(input, inputType, stream);
             }
             default -> throw new ElasticsearchStatusException(
                 Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -140,7 +140,7 @@ public abstract class SenderService implements InferenceService {
         }
 
         // a non-null query is not supported and is dropped by all providers
-        doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
+        doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
     }
 
     protected abstract void doInfer(
@@ -164,7 +164,7 @@ public abstract class SenderService implements InferenceService {
 
     protected abstract void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java

@@ -71,8 +71,9 @@ public class AlibabaCloudSearchEmbeddingsRequestManager extends AlibabaCloudSear
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -321,7 +322,7 @@ public class AlibabaCloudSearchService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -336,14 +337,14 @@ public class AlibabaCloudSearchService extends SenderService {
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java

@@ -71,8 +71,8 @@ public class AlibabaCloudSearchSparseRequestManager extends AlibabaCloudSearchRe
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java

@@ -56,8 +56,8 @@ public class AmazonBedrockEmbeddingsRequestManager extends AmazonBedrockRequestM
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var serviceSettings = embeddingsModel.getServiceSettings();

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -136,7 +137,7 @@ public class AmazonBedrockService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -147,14 +148,14 @@ public class AmazonBedrockService extends SenderService {
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
 
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 maxBatchSize,
                 baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
         } else {
             listener.onFailure(createInvalidModelException(model));

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,7 +27,6 @@ import org.elasticsearch.inference.SettingsConfiguration;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -222,7 +222,7 @@ public class AnthropicService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java

@@ -50,8 +50,8 @@ public class AzureAiStudioEmbeddingsRequestManager extends AzureAiStudioRequestM
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -123,7 +124,7 @@ public class AzureAiStudioService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -133,14 +134,14 @@ public class AzureAiStudioService extends SenderService {
             var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
 
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 EMBEDDING_MAX_BATCH_SIZE,
                 baseAzureAiStudioModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
         } else {
             listener.onFailure(createInvalidModelException(model));

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java

@@ -63,8 +63,8 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -272,7 +273,7 @@ public class AzureOpenAiService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -286,14 +287,14 @@ public class AzureOpenAiService extends SenderService {
         var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             azureOpenAiModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = azureOpenAiModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -274,7 +275,7 @@ public class CohereService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -289,14 +290,14 @@ public class CohereService extends SenderService {
         var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             cohereModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = cohereModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java

@@ -81,8 +81,8 @@ public class CohereActionCreator implements CohereActionVisitor {
                 : overriddenModel.getTaskSettings().getInputType();
 
             return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) {
-                case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel);
-                case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel);
+                case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel);
+                case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel);
             };
         };
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java

@@ -75,7 +75,7 @@ public class CustomRequestManager extends BaseRequestManager {
             requestParameters = CompletionParameters.of(chatInputs);
         } else if (inferenceInputs instanceof EmbeddingsInput) {
             requestParameters = EmbeddingParameters.of(
-                EmbeddingsInput.of(inferenceInputs),
+                inferenceInputs.castTo(EmbeddingsInput.class),
                 model.getServiceSettings().getInputTypeTranslator()
             );
         } else {

+ 5 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -136,7 +137,7 @@ public class CustomService extends SenderService {
             case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input")));
             case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input")));
             case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of(
-                new EmbeddingsInput(List.of("test input"), null, null),
+                new EmbeddingsInput(List.of("test input"), null),
                 model.getServiceSettings().getInputTypeTranslator()
             );
             default -> throw new IllegalStateException(
@@ -280,7 +281,7 @@ public class CustomService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -298,14 +299,14 @@ public class CustomService extends SenderService {
         var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             customModel.getServiceSettings().getBatchSize(),
             customModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java

@@ -28,7 +28,7 @@ public class EmbeddingParameters extends RequestParameters {
     private final InputTypeTranslator translator;
 
     private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) {
-        super(embeddingsInput.getStringInputs());
+        super(embeddingsInput.getInputs());
         this.inputType = embeddingsInput.getInputType();
         this.translator = translator;
     }

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -25,7 +26,6 @@ import org.elasticsearch.inference.SettingsConfiguration;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -106,7 +106,7 @@ public class DeepSeekService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 7 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
@@ -351,7 +352,7 @@ public class ElasticInferenceService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -361,14 +362,14 @@ public class ElasticInferenceService extends SenderService {
             var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
 
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE,
                 denseTextEmbeddingsModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
 
             return;
@@ -378,14 +379,14 @@ public class ElasticInferenceService extends SenderService {
             var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
 
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
                 model.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
 
             return;
@@ -618,7 +619,7 @@ public class ElasticInferenceService extends SenderService {
 
     private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
         if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
-            var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs();
+            var inputsAsList = inputs.castTo(EmbeddingsInput.class).getInputs();
             return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults);
         } else if (inferenceResults instanceof ErrorInferenceResults error) {
             return List.of(new ChunkedInferenceError(error.getException()));

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

@@ -69,8 +69,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java

@@ -97,7 +97,7 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
             DENSE_TEXT_EMBEDDINGS_HANDLER,
             (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest(
                 model,
-                embeddingsInput.getStringInputs(),
+                embeddingsInput.getInputs(),
                 traceContext,
                 extractRequestMetadataFromThreadContext(threadPool.getThreadContext()),
                 embeddingsInput.getInputType()

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java

@@ -56,8 +56,8 @@ public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioReques
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

+ 4 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -340,7 +341,7 @@ public class GoogleAiStudioService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -349,19 +350,13 @@ public class GoogleAiStudioService extends SenderService {
         GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             googleAiStudioModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
-            doInfer(
-                model,
-                EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType),
-                taskSettings,
-                timeout,
-                request.listener()
-            );
+            doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener());
         }
     }
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java

@@ -64,8 +64,8 @@ public class GoogleVertexAiEmbeddingsRequestManager extends GoogleVertexAiReques
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -264,7 +265,7 @@ public class GoogleVertexAiService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -274,14 +275,14 @@ public class GoogleVertexAiService extends SenderService {
         var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             googleVertexAiModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = googleVertexAiModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java

@@ -62,8 +62,8 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
-        var truncatedInput = truncate(docsInput, model.getTokenLimit());
+        List<String> inputs = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
+        var truncatedInput = truncate(inputs, model.getTokenLimit());
         var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

@@ -13,6 +13,7 @@ import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -140,7 +141,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -155,14 +156,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             huggingFaceModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = huggingFaceModel.accept(actionCreator);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 17 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

@@ -94,7 +94,7 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -105,22 +105,31 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
         );
 
         // TODO chunking sparse embeddings not implemented
-        doInfer(model, inputs, taskSettings, timeout, inferListener);
+        doInfer(
+            model,
+            new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType),
+            taskSettings,
+            timeout,
+            inferListener
+        );
     }
 
-    private static List<ChunkedInference> translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) {
+    private static List<ChunkedInference> translateToChunkedResults(
+        List<ChunkInferenceInput> inputs,
+        InferenceServiceResults inferenceResults
+    ) {
         if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) {
-            validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size());
+            validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size());
 
-            var results = new ArrayList<ChunkedInference>(inputs.getInputs().size());
+            var results = new ArrayList<ChunkedInference>(inputs.size());
 
-            for (int i = 0; i < inputs.getInputs().size(); i++) {
+            for (int i = 0; i < inputs.size(); i++) {
                 results.add(
                     new ChunkedInferenceEmbedding(
                         List.of(
                             new EmbeddingResults.Chunk(
                                 textEmbeddingResults.embeddings().get(i),
-                                new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length())
+                                new ChunkedInference.TextOffset(0, inputs.get(i).input().length())
                             )
                         )
                     )
@@ -128,7 +137,7 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
             }
             return results;
         } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
-            var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs());
+            var inputsAsList = ChunkInferenceInput.inputs(inputs);
             return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults);
         } else if (inferenceResults instanceof ErrorInferenceResults error) {
             return List.of(new ChunkedInferenceError(error.getException()));

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java

@@ -55,7 +55,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
+        List<String> docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
 
         execute(

+ 8 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -297,7 +298,7 @@ public class IbmWatsonxService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput input,
+        List<ChunkInferenceInput> input,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -306,13 +307,16 @@ public class IbmWatsonxService extends SenderService {
         IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model;
 
         var batchedRequests = new EmbeddingRequestChunker<>(
-            input.getInputs(),
+            input,
             EMBEDDING_MAX_BATCH_SIZE,
             ibmWatsonxModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
+
+        var actionCreator = getActionCreator(getSender(), getServiceComponents());
+
         for (var request : batchedRequests) {
-            var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            var action = ibmWatsonxModel.accept(actionCreator, taskSettings);
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java

@@ -52,8 +52,8 @@ public class JinaAIEmbeddingsRequestManager extends JinaAIRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model);

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -265,7 +266,7 @@ public class JinaAIService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -280,14 +281,14 @@ public class JinaAIService extends SenderService {
         var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             jinaaiModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = jinaaiModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java

@@ -61,7 +61,7 @@ public class MistralEmbeddingsRequestManager extends BaseRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
+        List<String> docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
         MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model);
 

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -146,7 +147,7 @@ public class MistralService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -156,14 +157,14 @@ public class MistralService extends SenderService {
 
         if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 MistralConstants.MAX_BATCH_SIZE,
                 mistralEmbeddingsModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
         } else {
             listener.onFailure(createInvalidModelException(model));

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -325,7 +326,7 @@ public class OpenAiService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -340,14 +341,14 @@ public class OpenAiService extends SenderService {
         var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             openAiModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = openAiModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -285,7 +286,7 @@ public class VoyageAIService extends SenderService {
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -300,14 +301,14 @@ public class VoyageAIService extends SenderService {
         var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             getBatchSize(voyageaiModel),
             voyageaiModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = voyageaiModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java

@@ -57,7 +57,7 @@ public class VoyageAIActionCreator implements VoyageAIActionVisitor {
             overriddenModel,
             EMBEDDINGS_HANDLER,
             (embeddingsInput) -> new VoyageAIEmbeddingsRequest(
-                embeddingsInput.getStringInputs(),
+                embeddingsInput.getInputs(),
                 embeddingsInput.getInputType(),
                 overriddenModel
             ),

+ 0 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.external.action;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
@@ -65,8 +64,6 @@ public class SingleInputSenderExecutableActionTests extends ESTestCase {
 
     public void testMoreThanOneInput() {
         var badInput = mock(EmbeddingsInput.class);
-        var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two"));
-        when(badInput.getInputs()).thenReturn(input);
         when(badInput.isSingleInput()).thenReturn(false);
         var actualException = new AtomicReference<Exception>();
 

+ 41 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java

@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.is;
+
+public class EmbeddingsInputTests extends ESTestCase {
+    public void testCallingGetInputs_invokesSupplier() {
+        AtomicBoolean invoked = new AtomicBoolean();
+        final List<String> list = List.of("input1", "input2");
+        Supplier<List<String>> supplier = () -> {
+            invoked.set(true);
+            return list;
+        };
+        EmbeddingsInput input = new EmbeddingsInput(supplier, null);
+        // Ensure we don't invoke the supplier until we call getInputs()
+        assertThat(invoked.get(), is(false));
+
+        assertThat(input.getInputs(), is(list));
+        assertThat(invoked.get(), is(true));
+    }
+
+    public void testCallingGetInputsTwice_throws() {
+        Supplier<List<String>> supplier = () -> List.of("input");
+        EmbeddingsInput input = new EmbeddingsInput(supplier, null);
+        input.getInputs();
+        var exception = expectThrows(AssertionError.class, input::getInputs);
+        assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice"));
+    }
+}

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

@@ -14,7 +14,6 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
@@ -143,7 +142,7 @@ public class HttpRequestSenderTests extends ESTestCase {
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             sender.send(
                 OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool),
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null),
+                new EmbeddingsInput(List.of("abc"), null),
                 null,
                 listener
             );

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java

@@ -17,7 +17,7 @@ import java.util.List;
 
 public class InferenceInputsTests extends ESTestCase {
     public void testCastToSucceeds() {
-        InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull(), false);
+        InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull());
         assertThat(inputs.castTo(EmbeddingsInput.class), Matchers.instanceOf(EmbeddingsInput.class));
 
         var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null);
@@ -29,7 +29,7 @@ public class InferenceInputsTests extends ESTestCase {
     }
 
     public void testCastToFails() {
-        InferenceInputs inputs = new EmbeddingsInput(List.of(), null, false);
+        InferenceInputs inputs = new EmbeddingsInput(List.of(), null);
         var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class));
         assertThat(
             exception.getMessage(),

+ 5 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java

@@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.Scheduler;
@@ -62,7 +61,7 @@ public class RequestTaskTests extends ESTestCase {
 
         var requestTask = new RequestTask(
             OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool),
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             TimeValue.timeValueMillis(1),
             mockThreadPool,
             listener
@@ -82,7 +81,7 @@ public class RequestTaskTests extends ESTestCase {
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         var requestTask = new RequestTask(
             OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool),
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             TimeValue.timeValueMillis(1),
             threadPool,
             listener
@@ -106,7 +105,7 @@ public class RequestTaskTests extends ESTestCase {
 
         var requestTask = new RequestTask(
             OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool),
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             TimeValue.timeValueMillis(1),
             threadPool,
             listener
@@ -135,7 +134,7 @@ public class RequestTaskTests extends ESTestCase {
 
         var requestTask = new RequestTask(
             OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool),
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             TimeValue.timeValueMillis(1),
             threadPool,
             listener
@@ -162,7 +161,7 @@ public class RequestTaskTests extends ESTestCase {
 
         var requestTask = new RequestTask(
             OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool),
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             TimeValue.timeValueMillis(1),
             mockThreadPool,
             listener

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -20,7 +21,6 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -131,7 +131,7 @@ public class SenderServiceTests extends ESTestCase {
         @Override
         protected void doChunkedInfer(
             Model model,
-            EmbeddingsInput inputs,
+            List<ChunkInferenceInput> inputs,
             Map<String, Object> taskSettings,
             InputType inputType,
             TimeValue timeout,

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java

@@ -122,7 +122,7 @@ public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         assertThrows(IllegalArgumentException.class, () -> {
             action.execute(
-                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST),
+                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java

@@ -74,7 +74,7 @@ public class AmazonBedrockActionCreatorTests extends ESTestCase {
             var action = creator.create(model, Map.of());
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -112,7 +112,7 @@ public class AmazonBedrockActionCreatorTests extends ESTestCase {
             var action = creator.create(model, Map.of());
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomWithNull();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
             var result = listener.actionGet(TIMEOUT);
 
             assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
@@ -145,7 +145,7 @@ public class AmazonBedrockActionCreatorTests extends ESTestCase {
             );
             var action = creator.create(model, Map.of());
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
             var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
             assertThat(sender.sendCount(), is(1));

+ 4 - 5
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java

@@ -12,7 +12,6 @@ import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
@@ -83,10 +82,10 @@ public class AmazonBedrockMockRequestSender implements Sender {
         ActionListener<InferenceServiceResults> listener
     ) {
         sendCounter++;
-        if (inferenceInputs instanceof EmbeddingsInput docsInput) {
-            inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs()));
-            if (docsInput.getInputType() != null) {
-                inputTypes.add(docsInput.getInputType());
+        if (inferenceInputs instanceof EmbeddingsInput embeddingsInput) {
+            inputs.add(embeddingsInput.getInputs());
+            if (embeddingsInput.getInputType() != null) {
+                inputTypes.add(embeddingsInput.getInputType());
             }
         } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) {
             inputs.add(chatCompletionInput.getInputs());

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -112,7 +111,7 @@ public class AmazonBedrockRequestSenderTests extends ESTestCase {
                 threadPool,
                 new TimeValue(30, TimeUnit.SECONDS)
             );
-            sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener);
+            sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener);
 
             var result = listener.actionGet(TIMEOUT);
             assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F }))));

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java

@@ -115,7 +115,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             var action = creator.create(model, Map.of());
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomSearchAndIngestWithNull();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 

+ 6 - 10
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java

@@ -119,7 +119,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -170,7 +170,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -222,7 +222,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var failureCauseMessage = "Required [data]";
             var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
@@ -296,7 +296,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -373,7 +373,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -433,11 +433,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(
-                new EmbeddingsInput(List.of("super long input"), null, inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 

+ 5 - 26
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java

@@ -14,7 +14,6 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
@@ -117,11 +116,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomWithNull();
-            action.execute(
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -149,11 +144,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         var inputType = InputTypeTests.randomWithNull();
-        action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener
-        );
+        action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -174,11 +165,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         var inputType = InputTypeTests.randomWithNull();
-        action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener
-        );
+        action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -199,11 +186,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         var inputType = InputTypeTests.randomWithNull();
-        action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener
-        );
+        action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -218,11 +201,7 @@ public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         var inputType = InputTypeTests.randomWithNull();
-        action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener
-        );
+        action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java

@@ -119,7 +119,7 @@ public class CohereActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 

+ 7 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java

@@ -124,7 +124,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             InputType inputType = InputTypeTests.randomWithNull();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -207,7 +207,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
             );
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -262,7 +262,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -282,7 +282,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -302,7 +302,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
         var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -316,7 +316,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -330,7 +330,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase {
         var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java

@@ -77,7 +77,7 @@ public class CustomRequestManagerTests extends ESTestCase {
 
         var listener = new PlainActionFuture<InferenceServiceResults>();
         var manager = CustomRequestManager.of(model, threadPool);
-        manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener);
+        manager.execute(new EmbeddingsInput(List.of("abc", "123"), null), mock(RequestSender.class), () -> false, listener);
 
         var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30)));
 

+ 3 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

@@ -82,10 +82,7 @@ public class CustomRequestTests extends ESTestCase {
         );
 
         var request = new CustomRequest(
-            EmbeddingParameters.of(
-                new EmbeddingsInput(List.of("abc", "123"), null, null),
-                model.getServiceSettings().getInputTypeTranslator()
-            ),
+            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc", "123"), null), model.getServiceSettings().getInputTypeTranslator()),
             model
         );
         var httpRequest = request.createHttpRequest();
@@ -146,7 +143,7 @@ public class CustomRequestTests extends ESTestCase {
 
         var request = new CustomRequest(
             EmbeddingParameters.of(
-                new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST),
+                new EmbeddingsInput(List.of("abc", "123"), InputType.INGEST),
                 model.getServiceSettings().getInputTypeTranslator()
             ),
             model
@@ -207,7 +204,7 @@ public class CustomRequestTests extends ESTestCase {
 
         var request = new CustomRequest(
             EmbeddingParameters.of(
-                new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH),
+                new EmbeddingsInput(List.of("abc", "123"), InputType.SEARCH),
                 model.getServiceSettings().getInputTypeTranslator()
             ),
             model

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java

@@ -21,7 +21,7 @@ public class EmbeddingParametersTests extends ESTestCase {
 
     public void testTaskTypeParameters_UsesDefaultValue() {
         var parameters = EmbeddingParameters.of(
-            new EmbeddingsInput(List.of("input"), null, InputType.INGEST),
+            new EmbeddingsInput(List.of("input"), InputType.INGEST),
             new InputTypeTranslator(Map.of(), "default")
         );
 
@@ -30,7 +30,7 @@ public class EmbeddingParametersTests extends ESTestCase {
 
     public void testTaskTypeParameters_UsesMappedValue() {
         var parameters = EmbeddingParameters.of(
-            new EmbeddingsInput(List.of("input"), null, InputType.INGEST),
+            new EmbeddingsInput(List.of("input"), InputType.INGEST),
             new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default")
         );
 

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java

@@ -64,7 +64,7 @@ public class CustomResponseEntityTests extends ESTestCase {
             new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")
         );
         var request = new CustomRequest(
-            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()),
+            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()),
             model
         );
         InferenceServiceResults results = CustomResponseEntity.fromResponse(
@@ -115,7 +115,7 @@ public class CustomResponseEntityTests extends ESTestCase {
             )
         );
         var request = new CustomRequest(
-            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()),
+            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()),
             model
 
         );

+ 8 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java

@@ -102,7 +102,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -163,7 +163,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -291,7 +291,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world", "second text"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -347,7 +347,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH),
+                new EmbeddingsInput(List.of("search query"), InputType.SEARCH),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -402,7 +402,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -443,7 +443,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
             var action = actionCreator.create(model);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of(), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -492,7 +492,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -556,7 +556,7 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
+                new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );

+ 4 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java

@@ -108,7 +108,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomWithNull();
-            action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -163,7 +163,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -187,7 +187,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -205,7 +205,7 @@ public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java

@@ -75,7 +75,7 @@ public class GoogleVertexAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -99,7 +99,7 @@ public class GoogleVertexAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -117,7 +117,7 @@ public class GoogleVertexAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java

@@ -71,7 +71,7 @@ public class GoogleVertexAiRerankActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "projectId", sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -91,7 +91,7 @@ public class GoogleVertexAiRerankActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "projectId", sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
@@ -105,7 +105,7 @@ public class GoogleVertexAiRerankActionTests extends ESTestCase {
         var action = createAction(getUrl(webServer), "projectId", sender);
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+        action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 

+ 6 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java

@@ -101,7 +101,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -172,7 +172,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -224,7 +224,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -286,7 +286,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -409,7 +409,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -471,7 +471,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java

@@ -63,7 +63,7 @@ public class HuggingFaceActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -87,7 +87,7 @@ public class HuggingFaceActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -108,7 +108,7 @@ public class HuggingFaceActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );

+ 4 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java

@@ -114,7 +114,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -144,7 +144,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -173,7 +173,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -197,7 +197,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );

+ 7 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java

@@ -109,7 +109,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -166,7 +166,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -222,7 +222,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -285,7 +285,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -625,7 +625,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -712,7 +712,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -784,7 +784,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );

+ 6 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java

@@ -14,7 +14,6 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockResponse;
@@ -115,7 +114,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(
-                new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+                new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
             );
@@ -155,7 +154,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -179,7 +178,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -203,7 +202,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -221,7 +220,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -239,7 +238,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );

+ 1 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java

@@ -11,7 +11,6 @@ import org.apache.http.HttpHeaders;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
@@ -111,11 +110,7 @@ public class VoyageAIActionCreatorTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomSearchAndIngestWithNull();
-            action.execute(
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 

+ 7 - 20
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java

@@ -14,7 +14,6 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
@@ -121,11 +120,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomSearchAndIngestWithNull();
-            action.execute(
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -222,11 +217,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomSearchAndIngestWithNull();
-            action.execute(
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -323,11 +314,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             var inputType = InputTypeTests.randomSearchAndIngestWithNull();
-            action.execute(
-                new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
+            action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -396,7 +383,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -420,7 +407,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -438,7 +425,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
 
         PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
         action.execute(
-            new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()),
+            new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()),
             InferenceAction.Request.DEFAULT_TIMEOUT,
             listener
         );
@@ -461,7 +448,7 @@ public class VoyageAIEmbeddingsActionTests extends ESTestCase {
             threadPool,
             model,
             EMBEDDINGS_HANDLER,
-            (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model),
+            (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model),
             EmbeddingsInput.class
         );