Browse Source

[ML] Do not convert input Strings to ChunkInferenceInput unless necessary (#134945) (#135024)

SenderService.infer() converted its input from a List<String> into a
List<ChunkInferenceInput>, but SenderService.createInput() immediately
converted that list back into a List<String>. To avoid this unnecessary
round trip, pass the List<String> through unchanged and let a new
EmbeddingsInput constructor build the List<ChunkInferenceInput> only for
the embedding task types that need it.

(cherry picked from commit 8d2c4c312f073d857be2ef4417e01cbf17f1b00d)

Co-authored-by: Donal Evans <donal.evans@elastic.co>
Jonathan Buttner 3 weeks ago
parent
commit
0cda0f6b3a

+ 9 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

@@ -30,10 +30,6 @@ public class EmbeddingsInput extends InferenceInputs {
     private final Supplier<List<ChunkInferenceInput>> listSupplier;
     private final InputType inputType;
 
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
-        this(input, inputType, false);
-    }
-
     public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
         super(false);
         this.listSupplier = Objects.requireNonNull(inputSupplier);
@@ -41,7 +37,15 @@ public class EmbeddingsInput extends InferenceInputs {
     }
 
     public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
-        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false);
+        this(input, chunkingSettings, inputType, false);
+    }
+
+    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
+        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
+    }
+
+    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
+        this(input, inputType, false);
     }
 
     public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {

+ 5 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -71,31 +71,29 @@ public abstract class SenderService implements InferenceService {
         ActionListener<InferenceServiceResults> listener
     ) {
         init();
-        var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
-        var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
+        var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
         doInfer(model, inferenceInput, taskSettings, timeout, listener);
     }
 
     private static InferenceInputs createInput(
         SenderService service,
         Model model,
-        List<ChunkInferenceInput> input,
+        List<String> input,
         InputType inputType,
         @Nullable String query,
         @Nullable Boolean returnDocuments,
         @Nullable Integer topN,
         boolean stream
     ) {
-        List<String> textInput = ChunkInferenceInput.inputs(input);
         return switch (model.getTaskType()) {
-            case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream);
+            case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
             case RERANK -> {
                 ValidationException validationException = new ValidationException();
                 service.validateRerankParameters(returnDocuments, topN, validationException);
                 if (validationException.validationErrors().isEmpty() == false) {
                     throw validationException;
                 }
-                yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream);
+                yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
             }
             case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
                 ValidationException validationException = new ValidationException();
@@ -103,7 +101,7 @@ public abstract class SenderService implements InferenceService {
                 if (validationException.validationErrors().isEmpty() == false) {
                     throw validationException;
                 }
-                yield new EmbeddingsInput(input, inputType, stream);
+                yield new EmbeddingsInput(input, null, inputType, stream);
             }
             default -> throw new ElasticsearchStatusException(
                 Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),