[ML] Simplify the Inference Ingest Processor configuration (#100205)

Adds an `input_output` option that removes the need for a `field_map` and/or
target field. Multiple inputs can be specified in `input_output`.
David Kyle, 2 years ago
parent commit b055204b43
23 changed files with 1054 additions and 78 deletions
  1. +5 -0     docs/changelog/100205.yaml
  2. +52 -0    docs/reference/ingest/processors/inference.asciidoc
  3. +46 -3    server/src/main/java/org/elasticsearch/inference/InferenceResults.java
  4. +13 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
  5. +6 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java
  6. +7 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java
  7. +12 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java
  8. +7 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java
  9. +12 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java
  10. +5 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
  11. +13 -1   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
  12. +7 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java
  13. +7 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java
  14. +7 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java
  15. +6 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java
  16. +0 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
  17. +2 -0    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
  18. +0 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java
  19. +0 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java
  20. +133 -0  x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java
  21. +269 -37 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  22. +360 -20 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java
  23. +85 -12  x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

+ 5 - 0
docs/changelog/100205.yaml

@@ -0,0 +1,5 @@
+pr: 100205
+summary: Simplify the Inference Ingest Processor configuration
+area: Machine Learning
+type: enhancement
+issues: []

+ 52 - 0
docs/reference/ingest/processors/inference.asciidoc

@@ -17,12 +17,64 @@ ingested in the pipeline.
 |======
 | Name                                  | Required  | Default                        | Description
 | `model_id` .                          | yes       | -                              | (String) The ID or alias for the trained model, or the ID of the deployment.
+| `input_output`                        | no        | -                              | (List) Input fields for inference and output (destination) fields for the inference results. This option is incompatible with the `target_field` and `field_map` options.
 | `target_field`                        | no        | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
 | `field_map`                           | no        | If defined the model's default field map | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
 | `inference_config`                    | no        | The default settings defined in the model  | (Object) Contains the inference type and its options.
 include::common-options.asciidoc[]
 |======
 
+[discrete]
+[[inference-input-output-example]]
+==== Configuring input and output fields
+Select the `content` field for inference and write the result to `content_embedding`.
+
+[source,js]
+--------------------------------------------------
+{
+  "inference": {
+    "model_id": "model_deployment_for_inference",
+    "input_output": [
+        {
+            "input_field": "content",
+            "output_field": "content_embedding"
+        }
+    ]
+  }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+==== Configuring multiple inputs
+The `content` and `title` fields are read from the incoming document
+and sent to the model for inference. The inference output is written
+to `content_embedding` and `title_embedding` respectively.
+[source,js]
+--------------------------------------------------
+{
+  "inference": {
+    "model_id": "model_deployment_for_inference",
+    "input_output": [
+        {
+            "input_field": "content",
+            "output_field": "content_embedding"
+        },
+        {
+            "input_field": "title",
+            "output_field": "title_embedding"
+        }
+    ]
+  }
+}
+--------------------------------------------------
+// NOTCONSOLE
+
+Selecting the input fields with `input_output` is incompatible with
+the `target_field` and `field_map` options.
+
+Data frame analytics models must use `target_field` to specify the
+root location where results are written, and optionally a `field_map`
+to map field names in the input document to the model's input fields.
 
 [source,js]
 --------------------------------------------------

+ 46 - 3
server/src/main/java/org/elasticsearch/inference/InferenceResults.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.xcontent.ToXContentFragment;
 
@@ -24,16 +25,58 @@ public interface InferenceResults extends NamedWriteable, ToXContentFragment {
         Objects.requireNonNull(resultField, "resultField");
         Map<String, Object> resultMap = results.asMap();
         resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
-        if (ingestDocument.hasField(resultField)) {
-            ingestDocument.appendFieldValue(resultField, resultMap);
+        setOrAppendValue(resultField, resultMap, ingestDocument);
+    }
+
+    static void writeResultToField(
+        InferenceResults results,
+        IngestDocument ingestDocument,
+        @Nullable String basePath,
+        String outputField,
+        String modelId,
+        boolean includeModelId
+    ) {
+        Objects.requireNonNull(results, "results");
+        Objects.requireNonNull(ingestDocument, "ingestDocument");
+        Objects.requireNonNull(outputField, "outputField");
+        Map<String, Object> resultMap = results.asMap(outputField);
+        if (includeModelId) {
+            resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
+        }
+        if (basePath == null) {
+            // insert the results into the root of the document
+            for (var entry : resultMap.entrySet()) {
+                setOrAppendValue(entry.getKey(), entry.getValue(), ingestDocument);
+            }
         } else {
-            ingestDocument.setFieldValue(resultField, resultMap);
+            for (var entry : resultMap.entrySet()) {
+                setOrAppendValue(basePath + "." + entry.getKey(), entry.getValue(), ingestDocument);
+            }
+        }
+    }
+
+    private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) {
+        if (ingestDocument.hasField(path)) {
+            ingestDocument.appendFieldValue(path, value);
+        } else {
+            ingestDocument.setFieldValue(path, value);
         }
     }
 
     String getResultsField();
 
+    /**
+     * Convert to a map
+     * @return Map representation of the InferenceResult
+     */
     Map<String, Object> asMap();
 
+    /**
+     * Convert to a map placing the inference result in {@code outputField}
+     * @param outputField Write the inference result to this field
+     * @return Map representation of the InferenceResult
+     */
+    Map<String, Object> asMap(String outputField);
+
     Object predictedValue();
 }
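
For readers unfamiliar with `IngestDocument`, here is a minimal sketch of the set-or-append behaviour used above, with a plain `Map` standing in for the document (an assumption for illustration only): the first write to a path sets the value, and a later write to the same path collects the values into a list.

[source,java]
--------------------------------------------------
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SetOrAppendSketch {

    // Simplified stand-in for IngestDocument.setFieldValue/appendFieldValue on a flat map:
    // absent path -> set the value; existing path -> gather old and new values into a list.
    static void setOrAppendValue(String path, Object value, Map<String, Object> doc) {
        doc.merge(path, value, (existing, incoming) -> {
            List<Object> values = new ArrayList<>();
            if (existing instanceof List<?> list) {
                values.addAll(list);
            } else {
                values.add(existing);
            }
            values.add(incoming);
            return values;
        });
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new LinkedHashMap<>();
        setOrAppendValue("tags", "first", doc);
        setOrAppendValue("tags", "second", doc);
        System.out.println(doc); // {tags=[first, second]}
    }
}
--------------------------------------------------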

+ 13 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

@@ -220,6 +220,19 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
     public Map<String, Object> asMap() {
         Map<String, Object> map = new LinkedHashMap<>();
         map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
+        addSupportingFieldsToMap(map);
+        return map;
+    }
+
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(outputField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
+        addSupportingFieldsToMap(map);
+        return map;
+    }
+
+    private void addSupportingFieldsToMap(Map<String, Object> map) {
         if (topClasses.isEmpty() == false) {
             map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
         }
@@ -235,7 +248,6 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                 featureImportance.stream().map(ClassificationFeatureImportance::toMap).collect(Collectors.toList())
             );
         }
-        return map;
     }
 
     @Override

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java

@@ -74,6 +74,12 @@ public class ErrorInferenceResults implements InferenceResults {
         return asMap;
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        // errors do not have a result
+        return asMap();
+    }
+
     @Override
     public String toString() {
         return Strings.toString(this);

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java

@@ -55,6 +55,13 @@ public class FillMaskResults extends NlpClassificationInferenceResults {
         map.put(resultsField + "_sequence", predictedSequence);
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField + "_sequence", predictedSequence);
+        return map;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME;

+ 12 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java

@@ -59,10 +59,21 @@ abstract class NlpInferenceResults implements InferenceResults {
     public final Map<String, Object> asMap() {
         Map<String, Object> map = new LinkedHashMap<>();
         addMapFields(map);
+        addSupportingFieldsToMap(map);
+        return map;
+    }
+
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        Map<String, Object> map = new LinkedHashMap<>();
+        addSupportingFieldsToMap(map);
+        return map;
+    }
+
+    private void addSupportingFieldsToMap(Map<String, Object> map) {
         if (isTruncated) {
             map.put("is_truncated", isTruncated);
         }
-        return map;
     }
 
     @Override

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java

@@ -65,6 +65,13 @@ public class PyTorchPassThroughResults extends NlpInferenceResults {
         map.put(resultsField, inference);
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField, inference);
+        return map;
+    }
+
     @Override
     public Object predictedValue() {
         throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");

+ 12 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java

@@ -121,6 +121,18 @@ public class QuestionAnsweringInferenceResults extends NlpInferenceResults {
     @Override
     void addMapFields(Map<String, Object> map) {
         map.put(resultsField, answer);
+        addSupportingFieldsToMap(map);
+    }
+
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField, answer);
+        addSupportingFieldsToMap(map);
+        return map;
+    }
+
+    private void addSupportingFieldsToMap(Map<String, Object> map) {
         map.put(START_OFFSET.getPreferredName(), startOffset);
         map.put(END_OFFSET.getPreferredName(), endOffset);
         if (topClasses.isEmpty() == false) {

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

@@ -69,6 +69,11 @@ public class RawInferenceResults implements InferenceResults {
         throw new UnsupportedOperationException("[raw] does not support map conversion");
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        throw new UnsupportedOperationException("[raw] does not support map conversion");
+    }
+
     @Override
     public Object predictedValue() {
         return null;

+ 13 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

@@ -121,11 +121,23 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     @Override
     public Map<String, Object> asMap() {
         Map<String, Object> map = new LinkedHashMap<>();
+        addSupportingFieldsToMap(map);
         map.put(resultsField, value());
+        return map;
+    }
+
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        Map<String, Object> map = new LinkedHashMap<>();
+        addSupportingFieldsToMap(map);
+        map.put(outputField, value());
+        return map;
+    }
+
+    private void addSupportingFieldsToMap(Map<String, Object> map) {
         if (featureImportance.isEmpty() == false) {
             map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
         }
-        return map;
     }
 
     @Override

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java

@@ -72,6 +72,13 @@ public class TextEmbeddingResults extends NlpInferenceResults {
         map.put(resultsField, inference);
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField, inference);
+        return map;
+    }
+
     @Override
     public Object predictedValue() {
         throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java

@@ -119,4 +119,11 @@ public class TextExpansionResults extends NlpInferenceResults {
     void addMapFields(Map<String, Object> map) {
         map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
     }
+
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
+        return map;
+    }
 }

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java

@@ -68,6 +68,13 @@ public class TextSimilarityInferenceResults extends NlpInferenceResults {
         map.put(resultsField, score);
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        var map = super.asMap(outputField);
+        map.put(outputField, score);
+        return map;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME;

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

@@ -75,6 +75,12 @@ public class WarningInferenceResults implements InferenceResults {
         return asMap;
     }
 
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        // warnings do not have a result
+        return asMap();
+    }
+
     @Override
     public Object predictedValue() {
         return null;

+ 0 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java

@@ -23,7 +23,6 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
 
     public static final ParseField NAME = new ParseField("classification");
 
-    public static final ParseField RESULTS_FIELD = new ParseField("results_field");
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
     public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
+import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xpack.core.ml.MlConfigVersion;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
@@ -15,6 +16,7 @@ public interface InferenceConfig extends NamedXContentObject, VersionedNamedWrit
 
     String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
     String DEFAULT_RESULTS_FIELD = "predicted_value";
+    ParseField RESULTS_FIELD = new ParseField("results_field");
 
     boolean isTargetTypeSupported(TargetType targetType);
 

+ 0 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java

@@ -15,7 +15,6 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse
     ParseField VOCABULARY = new ParseField("vocabulary");
     ParseField TOKENIZATION = new ParseField("tokenization");
     ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
-    ParseField RESULTS_FIELD = new ParseField("results_field");
     ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
 
     MlConfigVersion MINIMUM_NLP_SUPPORTED_VERSION = MlConfigVersion.V_8_0_0;

+ 0 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java

@@ -24,7 +24,6 @@ public class RegressionConfig implements LenientlyParsedInferenceConfig, Strictl
     public static final ParseField NAME = new ParseField("regression");
     private static final MlConfigVersion MIN_SUPPORTED_VERSION = MlConfigVersion.V_7_6_0;
     private static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.V_7_6_0;
-    public static final ParseField RESULTS_FIELD = new ParseField("results_field");
     public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
 
     public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);

+ 133 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java

@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+
+public class InferenceIngestInputConfigIT extends PyTorchModelRestTestCase {
+
+    @SuppressWarnings("unchecked")
+    public void testIngestWithInputFields() throws IOException {
+        String modelId = "test_ingest_with_input_fields";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        String inputOutput = """
+            [
+              {
+                "input_field": "body",
+                "output_field": "body_tokens"
+              }
+            ]
+            """;
+        String docs = """
+            [
+                {
+                  "_source": {
+                    "body": "these are"
+                  }
+                },
+                {
+                  "_source": {
+                    "body": "my words"
+                  }
+                }
+              ]
+            """;
+        var simulateResponse = simulatePipeline(pipelineDefinition(modelId, inputOutput), docs);
+        var responseMap = entityAsMap(simulateResponse);
+        var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
+        assertThat(simulatedDocs, hasSize(2));
+        assertNotNull(MapHelper.dig("doc._source.body_tokens", simulatedDocs.get(0)));
+        assertNotNull(MapHelper.dig("doc._source.body_tokens", simulatedDocs.get(1)));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testIngestWithMultipleInputFields() throws IOException {
+        String modelId = "test_ingest_with_multiple_input_fields";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        String inputOutput = """
+            [
+              {
+                "input_field": "title",
+                "output_field": "ml.body_tokens"
+              },
+              {
+                "input_field": "body",
+                "output_field": "ml.title_tokens"
+              }
+            ]
+            """;
+
+        String docs = """
+            [
+                {
+                  "_source": {
+                    "title": "my",
+                    "body": "these are"
+                  }
+                },
+                {
+                  "_source": {
+                    "title": "are",
+                    "body": "my words"
+                  }
+                }
+            ]
+            """;
+        var simulateResponse = simulatePipeline(pipelineDefinition(modelId, inputOutput), docs);
+        var responseMap = entityAsMap(simulateResponse);
+        var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
+        assertThat(simulatedDocs, hasSize(2));
+        assertNotNull(MapHelper.dig("doc._source.ml.title_tokens", simulatedDocs.get(0)));
+        assertNotNull(MapHelper.dig("doc._source.ml.body_tokens", simulatedDocs.get(0)));
+        assertNotNull(MapHelper.dig("doc._source.ml.title_tokens", simulatedDocs.get(1)));
+        assertNotNull(MapHelper.dig("doc._source.ml.body_tokens", simulatedDocs.get(1)));
+    }
+
+    private static String pipelineDefinition(String modelId, String inputOutput) {
+        return Strings.format("""
+            {
+              "processors": [
+                {
+                  "inference": {
+                    "model_id": "%s",
+                    "input_output": %s
+                  }
+                }
+              ]
+            }""", modelId, inputOutput);
+    }
+
+    private Response simulatePipeline(String pipelineDef, String docs) throws IOException {
+        String simulate = Strings.format("""
+            {
+              "pipeline": %s,
+              "docs": %s
+            }""", pipelineDef, docs);
+
+        Request request = new Request("POST", "_ingest/pipeline/_simulate?error_trace=true");
+        request.setJsonEntity(simulate);
+        return client().performRequest(request);
+    }
+}
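
The integration test above only exercises `_ingest/pipeline/_simulate`. For completeness, a minimal sketch of registering the same kind of pipeline for real indexing through the low-level REST client; the pipeline name is an assumption, and the model ID is taken from the documentation example.

[source,java]
--------------------------------------------------
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

import java.io.IOException;

public class PutInferencePipelineSketch {

    // `restClient` is assumed to be an already-configured low-level REST client.
    static Response putPipeline(RestClient restClient) throws IOException {
        Request request = new Request("PUT", "/_ingest/pipeline/content-embedding-pipeline");
        request.setJsonEntity("""
            {
              "processors": [
                {
                  "inference": {
                    "model_id": "model_deployment_for_inference",
                    "input_output": [
                      { "input_field": "content", "output_field": "content_embedding" }
                    ]
                  }
                }
              ]
            }""");
        return restClient.performRequest(request);
    }
}
--------------------------------------------------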

+ 269 - 37
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -17,6 +17,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.ingest.AbstractProcessor;
 import org.elasticsearch.ingest.ConfigurationUtils;
@@ -57,15 +58,17 @@ import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
-import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD;
+import static org.elasticsearch.ingest.ConfigurationUtils.newConfigurationException;
 import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -82,23 +85,57 @@ public class InferenceProcessor extends AbstractProcessor {
     );
 
     public static final String TYPE = "inference";
+    public static final String MODEL_ID = "model_id";
     public static final String INFERENCE_CONFIG = "inference_config";
+
+    // target field style mappings
     public static final String TARGET_FIELD = "target_field";
     public static final String FIELD_MAPPINGS = "field_mappings";
     public static final String FIELD_MAP = "field_map";
     private static final String DEFAULT_TARGET_FIELD = "ml.inference";
 
+    // input field config
+    public static final String INPUT_OUTPUT = "input_output";
+    public static final String INPUT_FIELD = "input_field";
+    public static final String OUTPUT_FIELD = "output_field";
+
+    public static InferenceProcessor fromInputFieldConfiguration(
+        Client client,
+        InferenceAuditor auditor,
+        String tag,
+        String description,
+        String modelId,
+        InferenceConfigUpdate inferenceConfig,
+        List<Factory.InputConfig> inputs
+    ) {
+        return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true);
+    }
+
+    public static InferenceProcessor fromTargetFieldConfiguration(
+        Client client,
+        InferenceAuditor auditor,
+        String tag,
+        String description,
+        String targetField,
+        String modelId,
+        InferenceConfigUpdate inferenceConfig,
+        Map<String, String> fieldMap
+    ) {
+        return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfig, fieldMap, null, false);
+    }
+
     private final Client client;
     private final String modelId;
-
     private final String targetField;
     private final InferenceConfigUpdate inferenceConfig;
     private final Map<String, String> fieldMap;
     private final InferenceAuditor auditor;
     private volatile boolean previouslyLicensed;
     private final AtomicBoolean shouldAudit = new AtomicBoolean(true);
+    private final List<Factory.InputConfig> inputs;
+    private final boolean configuredWithInputsFields;
 
-    public InferenceProcessor(
+    private InferenceProcessor(
         Client client,
         InferenceAuditor auditor,
         String tag,
@@ -106,15 +143,26 @@ public class InferenceProcessor extends AbstractProcessor {
         String targetField,
         String modelId,
         InferenceConfigUpdate inferenceConfig,
-        Map<String, String> fieldMap
+        Map<String, String> fieldMap,
+        List<Factory.InputConfig> inputs,
+        boolean configuredWithInputsFields
     ) {
         super(tag, description);
+        this.configuredWithInputsFields = configuredWithInputsFields;
         this.client = ExceptionsHelper.requireNonNull(client, "client");
-        this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
         this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
-        this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID_RESULTS_FIELD);
+        this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
         this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
-        this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP);
+
+        if (configuredWithInputsFields) {
+            this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT_OUTPUT);
+            this.targetField = null;
+            this.fieldMap = null;
+        } else {
+            this.inputs = null;
+            this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
+            this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP);
+        }
     }
 
     public String getModelId() {
@@ -123,11 +171,20 @@ public class InferenceProcessor extends AbstractProcessor {
 
     @Override
     public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
+
+        InferModelAction.Request request;
+        try {
+            request = buildRequest(ingestDocument);
+        } catch (ElasticsearchStatusException e) {
+            handler.accept(ingestDocument, e);
+            return;
+        }
+
         executeAsyncWithOrigin(
             client,
             ML_ORIGIN,
             InferModelAction.INSTANCE,
-            this.buildRequest(ingestDocument),
+            request,
             ActionListener.wrap(r -> handleResponse(r, ingestDocument, handler), e -> handler.accept(ingestDocument, e))
         );
     }
@@ -153,8 +210,21 @@ public class InferenceProcessor extends AbstractProcessor {
         if (ingestDocument.getIngestMetadata().isEmpty() == false) {
             fields.put(INGEST_KEY, ingestDocument.getIngestMetadata());
         }
-        LocalModel.mapFieldsIfNecessary(fields, fieldMap);
-        return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed);
+
+        if (configuredWithInputsFields) {
+            List<String> requestInputs = new ArrayList<>();
+            for (var inputFields : inputs) {
+                var lookup = (String) fields.get(inputFields.inputField);
+                if (lookup == null) {
+                    lookup = ""; // need to send a non-null request to the same number of results back
+                }
+                requestInputs.add(lookup);
+            }
+            return InferModelAction.Request.forTextInput(modelId, inferenceConfig, requestInputs);
+        } else {
+            LocalModel.mapFieldsIfNecessary(fields, fieldMap);
+            return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed);
+        }
     }
 
     void auditWarningAboutLicenseIfNecessary() {
@@ -171,13 +241,42 @@ public class InferenceProcessor extends AbstractProcessor {
         if (response.getInferenceResults().isEmpty()) {
             throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
         }
-        assert response.getInferenceResults().size() == 1;
-        InferenceResults.writeResult(
-            response.getInferenceResults().get(0),
-            ingestDocument,
-            targetField,
-            response.getId() != null ? response.getId() : modelId
-        );
+
+        // TODO
+        // The field where the model Id is written to.
+        // If multiple inference processors are in the same pipeline, it is wise to tag them
+        // The tag will keep default value entries from stepping on each other
+        // String modelIdField = tag == null ? MODEL_ID_RESULTS_FIELD : MODEL_ID_RESULTS_FIELD + "." + tag;
+
+        if (configuredWithInputsFields) {
+            if (response.getInferenceResults().size() != inputs.size()) {
+                throw new ElasticsearchStatusException(
+                    "number of results [{}] does not match the number of inputs [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    response.getInferenceResults().size(),
+                    inputs.size()
+                );
+            }
+
+            for (int i = 0; i < inputs.size(); i++) {
+                InferenceResults.writeResultToField(
+                    response.getInferenceResults().get(i),
+                    ingestDocument,
+                    inputs.get(i).outputBasePath(),
+                    inputs.get(i).outputField,
+                    response.getId() != null ? response.getId() : modelId,
+                    i == 0
+                );
+            }
+        } else {
+            assert response.getInferenceResults().size() == 1;
+            InferenceResults.writeResult(
+                response.getInferenceResults().get(0),
+                ingestDocument,
+                targetField,
+                response.getId() != null ? response.getId() : modelId
+            );
+        }
     }
 
     @Override
@@ -190,6 +289,30 @@ public class InferenceProcessor extends AbstractProcessor {
         return TYPE;
     }
 
+    boolean isConfiguredWithInputsFields() {
+        return configuredWithInputsFields;
+    }
+
+    public List<Factory.InputConfig> getInputs() {
+        return inputs;
+    }
+
+    Map<String, String> getFieldMap() {
+        return fieldMap;
+    }
+
+    String getTargetField() {
+        return targetField;
+    }
+
+    InferenceConfigUpdate getInferenceConfig() {
+        return inferenceConfig;
+    }
+
+    InferenceAuditor getAuditor() {
+        return auditor;
+    }
+
     public static final class Factory implements Processor.Factory, Consumer<ClusterState> {
 
         private static final Logger logger = LogManager.getLogger(Factory.class);
@@ -225,7 +348,6 @@ public class InferenceProcessor extends AbstractProcessor {
             String description,
             Map<String, Object> config
         ) {
-
             if (this.maxIngestProcessors <= currentInferenceProcessors) {
                 throw new ElasticsearchStatusException(
                     "Max number of inference processors reached, total inference processors [{}]. "
@@ -237,37 +359,91 @@ public class InferenceProcessor extends AbstractProcessor {
                 );
             }
 
-            String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID_RESULTS_FIELD);
-            String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag;
-            // If multiple inference processors are in the same pipeline, it is wise to tag them
-            // The tag will keep default value entries from stepping on each other
-            String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
-            Map<String, String> fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAP);
-            if (fieldMap == null) {
-                fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
-                // TODO Remove in 9?.x
-                if (fieldMap != null) {
-                    LoggingDeprecationHandler.INSTANCE.logRenamedField(null, () -> null, FIELD_MAPPINGS, FIELD_MAP);
-                }
-            }
-            if (fieldMap == null) {
-                fieldMap = Collections.emptyMap();
-            }
+            String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
 
             InferenceConfigUpdate inferenceConfigUpdate;
             Map<String, Object> inferenceConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INFERENCE_CONFIG);
             if (inferenceConfigMap == null) {
                 if (minNodeVersion.before(EmptyConfigUpdate.minimumSupportedVersion())) {
                     // an inference config is required when the empty update is not supported
-                    throw ConfigurationUtils.newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing");
+                    throw newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing");
                 }
-
                 inferenceConfigUpdate = new EmptyConfigUpdate();
             } else {
                 inferenceConfigUpdate = inferenceConfigUpdateFromMap(inferenceConfigMap);
             }
 
-            return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfigUpdate, fieldMap);
+            List<Map<String, Object>> inputs = ConfigurationUtils.readOptionalList(TYPE, tag, config, INPUT_OUTPUT);
+            boolean configuredWithInputFields = inputs != null;
+            if (configuredWithInputFields) {
+                // new style input/output configuration
+                var parsedInputs = parseInputFields(tag, inputs);
+
+                // validate incompatible settings are not present
+                String targetField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, TARGET_FIELD);
+                if (targetField != null) {
+                    throw newConfigurationException(
+                        TYPE,
+                        tag,
+                        TARGET_FIELD,
+                        "option is incompatible with ["
+                            + INPUT_OUTPUT
+                            + "]."
+                            + " Use the ["
+                            + OUTPUT_FIELD
+                            + "] option to specify where to write the inference results to."
+                    );
+                }
+
+                if (inferenceConfigUpdate.getResultsField() != null) {
+                    throw newConfigurationException(
+                        TYPE,
+                        tag,
+                        null,
+                        "The ["
+                            + INFERENCE_CONFIG
+                            + "."
+                            + InferenceConfig.RESULTS_FIELD.getPreferredName()
+                            + "] setting is incompatible with using ["
+                            + INPUT_OUTPUT
+                            + "]. Prefer to use the ["
+                            + INPUT_OUTPUT
+                            + "."
+                            + OUTPUT_FIELD
+                            + "] option to specify where to write the inference results to."
+                    );
+                }
+
+                return fromInputFieldConfiguration(client, auditor, tag, description, modelId, inferenceConfigUpdate, parsedInputs);
+            } else {
+                // old style configuration with target field
+                String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag;
+                // If multiple inference processors are in the same pipeline, it is wise to tag them
+                // The tag will keep default value entries from stepping on each other
+                String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
+                Map<String, String> fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAP);
+                if (fieldMap == null) {
+                    fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
+                    // TODO Remove in 9?.x
+                    if (fieldMap != null) {
+                        LoggingDeprecationHandler.INSTANCE.logRenamedField(null, () -> null, FIELD_MAPPINGS, FIELD_MAP);
+                    }
+                }
+
+                if (fieldMap == null) {
+                    fieldMap = Collections.emptyMap();
+                }
+                return fromTargetFieldConfiguration(
+                    client,
+                    auditor,
+                    tag,
+                    description,
+                    targetField,
+                    modelId,
+                    inferenceConfigUpdate,
+                    fieldMap
+                );
+            }
         }
 
         // Package private for testing
@@ -374,5 +550,61 @@ public class InferenceProcessor extends AbstractProcessor {
                 );
             }
         }
+
+        List<InputConfig> parseInputFields(String tag, List<Map<String, Object>> inputs) {
+            if (inputs.isEmpty()) {
+                throw newConfigurationException(TYPE, tag, INPUT_OUTPUT, "cannot be empty, at least one is required");
+            }
+            var inputNames = new HashSet<String>();
+            var outputNames = new HashSet<String>();
+            var parsedInputs = new ArrayList<InputConfig>();
+
+            for (var input : inputs) {
+                String inputField = ConfigurationUtils.readStringProperty(TYPE, tag, input, INPUT_FIELD);
+                String outputField = ConfigurationUtils.readStringProperty(TYPE, tag, input, OUTPUT_FIELD);
+
+                if (inputNames.add(inputField) == false) {
+                    throw duplicatedFieldNameError(INPUT_FIELD, inputField, tag);
+                }
+                if (outputNames.add(outputField) == false) {
+                    throw duplicatedFieldNameError(OUTPUT_FIELD, outputField, tag);
+                }
+
+                var outputPaths = extractBasePathAndFinalElement(outputField);
+
+                if (input.isEmpty()) {
+                    parsedInputs.add(new InputConfig(inputField, outputPaths.v1(), outputPaths.v2(), Map.of()));
+                } else {
+                    parsedInputs.add(new InputConfig(inputField, outputPaths.v1(), outputPaths.v2(), new HashMap<>(input)));
+                }
+            }
+
+            return parsedInputs;
+        }
+
+        private ElasticsearchException duplicatedFieldNameError(String property, String fieldName, String tag) {
+            return newConfigurationException(TYPE, tag, property, "names must be unique but [" + fieldName + "] is repeated");
+        }
+
+        /**
+         * {@code outputField} can be a dot '.' separated path of elements.
+         * Extract the base path (everything before the last '.') and the final
+         * element.
+         * If {@code outputField} does not contain any dotted elements the base
+         * path is null.
+         *
+         * @param outputField The path to split
+         * @return Tuple of {@code <basePath, finalElement>}
+         */
+        static Tuple<String, String> extractBasePathAndFinalElement(String outputField) {
+            int lastIndex = outputField.lastIndexOf('.');
+            if (lastIndex < 0) {
+                return new Tuple<>(null, outputField);
+            } else {
+                return new Tuple<>(outputField.substring(0, lastIndex), outputField.substring(lastIndex + 1));
+            }
+        }
+
+        public record InputConfig(String inputField, String outputBasePath, String outputField, Map<String, Object> extras) {}
     }
 }

+ 360 - 20
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.inference.ingest;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchParseException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterName;
@@ -26,7 +27,6 @@ import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Tuple;
-import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.ingest.IngestMetadata;
 import org.elasticsearch.ingest.PipelineConfiguration;
 import org.elasticsearch.test.ESTestCase;
@@ -36,21 +36,33 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.ml.MlConfigVersion;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.junit.Before;
 
 import java.io.IOException;
 import java.net.InetAddress;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -59,7 +71,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -89,7 +105,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
         clusterService = new ClusterService(settings, clusterSettings, tp, null);
     }
 
-    public void testCreateProcessorWithTooManyExisting() throws Exception {
+    public void testCreateProcessorWithTooManyExisting() {
         Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
 
         includeNodeInfoValues.forEach(includeNodeInfo -> {
@@ -135,7 +151,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> config = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap()));
                 }
@@ -158,7 +174,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> config2 = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom"));
                 }
@@ -172,7 +188,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> config3 = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap());
                 }
@@ -185,7 +201,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
         });
     }
 
-    public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
+    public void testCreateProcessorWithTooOldMinNodeVersion() {
         Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
 
         includeNodeInfoValues.forEach(includeNodeInfo -> {
@@ -203,7 +219,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> regression = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(
                         InferenceProcessor.INFERENCE_CONFIG,
@@ -224,7 +240,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> classification = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(
                         InferenceProcessor.INFERENCE_CONFIG,
@@ -315,7 +331,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
 
             Map<String, Object> minimalConfig = new HashMap<>() {
                 {
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                 }
             };
@@ -342,7 +358,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> regression = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(
                         InferenceProcessor.INFERENCE_CONFIG,
@@ -351,12 +367,18 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
                 }
             };
 
-            processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
+            var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
+            assertEquals(includeNodeInfo, processor.getAuditor().includeNodeInfo());
+            assertFalse(processor.isConfiguredWithInputsFields());
+            assertEquals("my_model", processor.getModelId());
+            assertEquals("result", processor.getTargetField());
+            assertThat(processor.getFieldMap().entrySet(), empty());
+            assertNull(processor.getInputs());
 
             Map<String, Object> classification = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                     put(
                         InferenceProcessor.INFERENCE_CONFIG,
@@ -368,19 +390,79 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
                 }
             };
 
-            processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification);
+            processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification);
+            assertFalse(processor.isConfiguredWithInputsFields());
 
             Map<String, Object> mininmal = new HashMap<>() {
                 {
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "result");
                 }
             };
 
-            processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal);
+            processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal);
+            assertFalse(processor.isConfiguredWithInputsFields());
+            assertEquals("my_model", processor.getModelId());
+            assertEquals("result", processor.getTargetField());
+            assertNull(processor.getInputs());
         });
     }
 
+    public void testCreateProcessorWithFieldMap() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false);
+
+        Map<String, Object> config = new HashMap<>() {
+            {
+                put(InferenceProcessor.FIELD_MAP, Collections.singletonMap("source", "dest"));
+                put(InferenceProcessor.MODEL_ID, "my_model");
+                put(InferenceProcessor.TARGET_FIELD, "result");
+                put(
+                    InferenceProcessor.INFERENCE_CONFIG,
+                    Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())
+                );
+            }
+        };
+
+        var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config);
+        assertFalse(processor.isConfiguredWithInputsFields());
+        assertEquals("my_model", processor.getModelId());
+        assertEquals("result", processor.getTargetField());
+        assertNull(processor.getInputs());
+        var fieldMap = processor.getFieldMap();
+        assertThat(fieldMap.entrySet(), hasSize(1));
+        assertThat(fieldMap, hasEntry("source", "dest"));
+    }
+
+    public void testCreateProcessorWithInputOutputs() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false);
+
+        Map<String, Object> config = new HashMap<>();
+        config.put(InferenceProcessor.MODEL_ID, "my_model");
+
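+        // two input/output pairs supplied through the input_output option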
+        Map<String, Object> input1 = new HashMap<>();
+        input1.put(InferenceProcessor.INPUT_FIELD, "in1");
+        input1.put(InferenceProcessor.OUTPUT_FIELD, "out1");
+        Map<String, Object> input2 = new HashMap<>();
+        input2.put(InferenceProcessor.INPUT_FIELD, "in2");
+        input2.put(InferenceProcessor.OUTPUT_FIELD, "out2");
+
+        List<Map<String, Object>> inputOutputs = new ArrayList<>();
+        inputOutputs.add(input1);
+        inputOutputs.add(input2);
+        config.put(InferenceProcessor.INPUT_OUTPUT, inputOutputs);
+
+        var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config);
+        assertTrue(processor.isConfiguredWithInputsFields());
+        assertEquals("my_model", processor.getModelId());
+        var configuredInputs = processor.getInputs();
+        assertThat(configuredInputs, hasSize(2));
+        assertEquals(configuredInputs.get(0).inputField(), "in1");
+        assertEquals(configuredInputs.get(0).outputField(), "out1");
+        assertEquals(configuredInputs.get(1).inputField(), "in2");
+        assertEquals(configuredInputs.get(1).outputField(), "out2");
+
+    }
+
     public void testCreateProcessorWithDuplicateFields() {
         Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
 
@@ -395,7 +477,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
             Map<String, Object> regression = new HashMap<>() {
                 {
                     put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-                    put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
+                    put(InferenceProcessor.MODEL_ID, "my_model");
                     put(InferenceProcessor.TARGET_FIELD, "ml");
                     put(
                         InferenceProcessor.INFERENCE_CONFIG,
@@ -415,7 +497,41 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
         });
     }
 
-    public void testParseFromMap() {
+    public void testCreateProcessorWithIgnoreMissing() {
+        Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
+
+        includeNodeInfoValues.forEach(includeNodeInfo -> {
+            InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+                client,
+                clusterService,
+                Settings.EMPTY,
+                includeNodeInfo
+            );
+
+            Map<String, Object> regression = new HashMap<>() {
+                {
+                    put(InferenceProcessor.MODEL_ID, "my_model");
+                    put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
+                    put("ignore_missing", Boolean.TRUE);
+                    put(
+                        InferenceProcessor.INFERENCE_CONFIG,
+                        Collections.singletonMap(
+                            RegressionConfig.NAME.getPreferredName(),
+                            Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")
+                        )
+                    );
+                }
+            };
+
+            Exception ex = expectThrows(
+                Exception.class,
+                () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression)
+            );
+            assertThat(ex.getMessage(), equalTo("Invalid inference config. " + "More than one field is configured as [warning]"));
+        });
+    }
+
+    public void testParseInferenceConfigFromMap() {
         Set<Boolean> includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false));
 
         includeNodeInfoValues.forEach(includeNodeInfo -> {
@@ -433,6 +549,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
                 Tuple.tuple(PassThroughConfig.NAME, Map.of()),
                 Tuple.tuple(TextClassificationConfig.NAME, Map.of()),
                 Tuple.tuple(TextEmbeddingConfig.NAME, Map.of()),
+                Tuple.tuple(TextExpansionConfig.NAME, Map.of()),
                 Tuple.tuple(ZeroShotClassificationConfig.NAME, Map.of()),
                 Tuple.tuple(QuestionAnsweringConfig.NAME, Map.of("question", "What is the answer to life, the universe and everything?"))
             )) {
@@ -444,8 +561,231 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
         });
     }
 
-    private static ClusterState buildClusterState(Metadata metadata) {
-        return ClusterState.builder(new ClusterName("_name")).metadata(metadata).build();
+    public void testCreateProcessorWithIncompatibleTargetFieldSetting() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        Map<String, Object> input = new HashMap<>() {
+            {
+                put(InferenceProcessor.INPUT_FIELD, "in");
+                put(InferenceProcessor.OUTPUT_FIELD, "out");
+            }
+        };
+
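+        // target_field cannot be combined with input_output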
+        Map<String, Object> config = new HashMap<>() {
+            {
+                put(InferenceProcessor.MODEL_ID, "my_model");
+                put(InferenceProcessor.TARGET_FIELD, "ml");
+                put(InferenceProcessor.INPUT_OUTPUT, List.of(input));
+            }
+        };
+
+        ElasticsearchParseException ex = expectThrows(
+            ElasticsearchParseException.class,
+            () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config)
+        );
+        assertThat(
+            ex.getMessage(),
+            containsString(
+                "[target_field] option is incompatible with [input_output]. Use the [output_field] option to specify where to write the "
+                    + "inference results to."
+            )
+        );
+    }
+
+    public void testCreateProcessorWithIncompatibleResultFieldSetting() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        Map<String, Object> input = new HashMap<>() {
+            {
+                put(InferenceProcessor.INPUT_FIELD, "in");
+                put(InferenceProcessor.OUTPUT_FIELD, "out");
+            }
+        };
+
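+        // inference_config.results_field cannot be combined with input_output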
+        Map<String, Object> config = new HashMap<>() {
+            {
+                put(InferenceProcessor.MODEL_ID, "my_model");
+                put(InferenceProcessor.INPUT_OUTPUT, List.of(input));
+                put(
+                    InferenceProcessor.INFERENCE_CONFIG,
+                    Collections.singletonMap(
+                        TextExpansionConfig.NAME,
+                        Collections.singletonMap(TextExpansionConfig.RESULTS_FIELD.getPreferredName(), "foo")
+                    )
+                );
+            }
+        };
+
+        ElasticsearchParseException ex = expectThrows(
+            ElasticsearchParseException.class,
+            () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config)
+        );
+        assertThat(
+            ex.getMessage(),
+            containsString(
+                "The [inference_config.results_field] setting is incompatible with using [input_output]. "
+                    + "Prefer to use the [input_output.output_field] option to specify where to write the inference results to."
+            )
+        );
+    }
+
+    public void testCreateProcessorWithInputFields() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        Map<String, Object> inputMap = new HashMap<>() {
+            {
+                put(InferenceProcessor.INPUT_FIELD, "in");
+                put(InferenceProcessor.OUTPUT_FIELD, "out");
+            }
+        };
+
+        String inferenceConfigType = randomFrom(
+            ClassificationConfigUpdate.NAME.getPreferredName(),
+            RegressionConfigUpdate.NAME.getPreferredName(),
+            FillMaskConfigUpdate.NAME,
+            NerConfigUpdate.NAME,
+            PassThroughConfigUpdate.NAME,
+            QuestionAnsweringConfigUpdate.NAME,
+            TextClassificationConfigUpdate.NAME,
+            TextEmbeddingConfigUpdate.NAME,
+            TextExpansionConfigUpdate.NAME,
+            TextSimilarityConfigUpdate.NAME,
+            ZeroShotClassificationConfigUpdate.NAME
+        );
+
+        Map<String, Object> config = new HashMap<>() {
+            {
+                put(InferenceProcessor.MODEL_ID, "my_model");
+                put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap));
+                put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap()));
+            }
+        };
+        // create valid inference configs with required fields
+        if (inferenceConfigType.equals(TextSimilarityConfigUpdate.NAME)) {
+            var inferenceConfig = new HashMap<String, String>();
+            inferenceConfig.put(TextSimilarityConfig.TEXT.getPreferredName(), "text to compare");
+            config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig));
+        } else if (inferenceConfigType.equals(QuestionAnsweringConfigUpdate.NAME)) {
+            var inferenceConfig = new HashMap<String, String>();
+            inferenceConfig.put(QuestionAnsweringConfig.QUESTION.getPreferredName(), "why is the sky blue?");
+            config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig));
+        } else {
+            config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap()));
+        }
+
+        var inferenceProcessor = processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config);
+        assertEquals("my_model", inferenceProcessor.getModelId());
+        assertTrue(inferenceProcessor.isConfiguredWithInputsFields());
+
+        var inputs = inferenceProcessor.getInputs();
+        assertThat(inputs, hasSize(1));
+        assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", null, "out", Map.of()));
+
+        assertNull(inferenceProcessor.getFieldMap());
+        assertNull(inferenceProcessor.getTargetField());
+    }
+
+    public void testParsingInputFields() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        int numInputs = randomIntBetween(1, 3);
+        List<Map<String, Object>> inputs = new ArrayList<>();
+        for (int i = 0; i < numInputs; i++) {
+            Map<String, Object> inputMap = new HashMap<>();
+            inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i);
+            inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out." + i);
+            inputs.add(inputMap);
+        }
+
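+        // an output path such as "out.0" should be split into base path "out" and final element "0"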
+        var parsedInputs = processorFactory.parseInputFields("my_processor", inputs);
+        assertThat(parsedInputs, hasSize(numInputs));
+        for (int i = 0; i < numInputs; i++) {
+            assertEquals(new InferenceProcessor.Factory.InputConfig("in" + i, "out", Integer.toString(i), Map.of()), parsedInputs.get(i));
+        }
+    }
+
+    public void testParsingInputFieldsDuplicateFieldNames() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        int numInputs = 2;
+        {
+            List<Map<String, Object>> inputs = new ArrayList<>();
+            for (int i = 0; i < numInputs; i++) {
+                Map<String, Object> inputMap = new HashMap<>();
+                inputMap.put(InferenceProcessor.INPUT_FIELD, "in");
+                inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out" + i);
+                inputs.add(inputMap);
+            }
+
+            var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs));
+            assertThat(e.getMessage(), containsString("[input_field] names must be unique but [in] is repeated"));
+        }
+
+        {
+            List<Map<String, Object>> inputs = new ArrayList<>();
+            for (int i = 0; i < numInputs; i++) {
+                Map<String, Object> inputMap = new HashMap<>();
+                inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i);
+                inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out");
+                inputs.add(inputMap);
+            }
+
+            var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs));
+            assertThat(e.getMessage(), containsString("[output_field] names must be unique but [out] is repeated"));
+        }
+    }
+
+    public void testExtractBasePathAndFinalElement() {
+        {
+            String path = "foo.bar.result";
+            var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path);
+            assertEquals("foo.bar", extractedPaths.v1());
+            assertEquals("result", extractedPaths.v2());
+        }
+
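+        // a path with no dot separator has no base path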
+        {
+            String path = "result";
+            var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path);
+            assertNull(extractedPaths.v1());
+            assertEquals("result", extractedPaths.v2());
+        }
+    }
+
+    public void testParsingInputFieldsGivenNoInputs() {
+        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(
+            client,
+            clusterService,
+            Settings.EMPTY,
+            randomBoolean()
+        );
+
+        var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", List.of()));
+        assertThat(e.getMessage(), containsString("[input_output] cannot be empty at least one is required"));
     }
 
     private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
@@ -513,7 +853,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
     private static Map<String, Object> inferenceProcessorForModel(String modelId) {
         return Collections.singletonMap(InferenceProcessor.TYPE, new HashMap<>() {
             {
-                put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId);
+                put(InferenceProcessor.MODEL_ID, modelId);
                 put(
                     InferenceProcessor.INFERENCE_CONFIG,
                     Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())

+ 85 - 12
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureIm
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -56,7 +57,7 @@ public class InferenceProcessorTests extends ESTestCase {
 
     public void testMutateDocumentWithClassification() {
         String targetField = "ml.my_processor";
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -89,7 +90,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentClassificationTopNClasses() {
         ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null, null);
         ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null, PredictionFieldType.STRING);
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -126,7 +127,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentClassificationFeatureInfluence() {
         ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2, PredictionFieldType.STRING);
         ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2, null);
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -180,7 +181,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
         ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops", null, PredictionFieldType.STRING);
         ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null, null);
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -217,7 +218,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentRegression() {
         RegressionConfig regressionConfig = new RegressionConfig("foo");
         RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null);
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -244,7 +245,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentRegressionWithTopFeatures() {
         RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
         RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2);
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -280,7 +281,7 @@ public class InferenceProcessorTests extends ESTestCase {
         String modelId = "model";
         Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
 
-        InferenceProcessor processor = new InferenceProcessor(
+        InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -320,7 +321,7 @@ public class InferenceProcessorTests extends ESTestCase {
         fieldMapping.put("categorical", "new_categorical");
         fieldMapping.put("_ingest._value", "metafield");
 
-        InferenceProcessor processor = new InferenceProcessor(
+        InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -362,7 +363,7 @@ public class InferenceProcessorTests extends ESTestCase {
         fieldMapping.put("value2", "new_value2");
         fieldMapping.put("categorical.bar", "new_categorical");
 
-        InferenceProcessor processor = new InferenceProcessor(
+        InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -390,7 +391,7 @@ public class InferenceProcessorTests extends ESTestCase {
 
     public void testHandleResponseLicenseChanged() {
         String targetField = "regression_value";
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -440,7 +441,7 @@ public class InferenceProcessorTests extends ESTestCase {
 
     public void testMutateDocumentWithWarningResult() {
         String targetField = "regression_value";
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -468,7 +469,7 @@ public class InferenceProcessorTests extends ESTestCase {
     public void testMutateDocumentWithModelIdResult() {
         String modelAlias = "special_model";
         String modelId = "regression-123";
-        InferenceProcessor inferenceProcessor = new InferenceProcessor(
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration(
             client,
             auditor,
             "my_processor",
@@ -491,4 +492,76 @@ public class InferenceProcessorTests extends ESTestCase {
         assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo(modelId));
     }
+
+    public void testMutateDocumentWithInputFields() {
+        String modelId = "regression-123";
+        List<InferenceProcessor.Factory.InputConfig> inputs = new ArrayList<>();
+        inputs.add(new InferenceProcessor.Factory.InputConfig("body", null, "body_result", Map.of()));
+        inputs.add(new InferenceProcessor.Factory.InputConfig("content", null, "content_result", Map.of()));
+
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration(
+            client,
+            auditor,
+            "my_processor_tag",
+            "description",
+            modelId,
+            new RegressionConfigUpdate("foo", null),
+            inputs
+        );
+
+        IngestDocument document = TestIngestDocument.emptyIngestDocument();
+
+        InferModelAction.Response response = new InferModelAction.Response(
+            List.of(new RegressionInferenceResults(0.7, "ignore"), new RegressionInferenceResults(1.0, "ignore")),
+            modelId,
+            true
+        );
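+        // each result in the response should be written to the corresponding input's output field, in order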
+        inferenceProcessor.mutateDocument(response, document);
+
+        assertThat(document.getFieldValue("body_result", Double.class), equalTo(0.7));
+        assertThat(document.getFieldValue("content_result", Double.class), equalTo(1.0));
+    }
+
+    public void testMutateDocumentWithInputFieldsNested() {
+        String modelId = "elser";
+        List<InferenceProcessor.Factory.InputConfig> inputs = new ArrayList<>();
+        inputs.add(new InferenceProcessor.Factory.InputConfig("body", "ml.results", "body_tokens", Map.of()));
+        inputs.add(new InferenceProcessor.Factory.InputConfig("content", "ml.results", "content_tokens", Map.of()));
+
+        InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration(
+            client,
+            auditor,
+            "my_processor_tag",
+            "description",
+            modelId,
+            new RegressionConfigUpdate("foo", null),
+            inputs
+        );
+
+        IngestDocument document = TestIngestDocument.emptyIngestDocument();
+
+        var teResult1 = TextExpansionResultsTests.createRandomResults();
+        var teResult2 = TextExpansionResultsTests.createRandomResults();
+        InferModelAction.Response response = new InferModelAction.Response(List.of(teResult1, teResult2), modelId, true);
+        inferenceProcessor.mutateDocument(response, document);
+
+        var bodyTokens = document.getFieldValue("ml.results.body_tokens", HashMap.class);
+        assertEquals(teResult1.getWeightedTokens().size(), bodyTokens.entrySet().size());
+        if (teResult1.getWeightedTokens().isEmpty() == false) {
+            assertEquals(
+                (float) bodyTokens.get(teResult1.getWeightedTokens().get(0).token()),
+                teResult1.getWeightedTokens().get(0).weight(),
+                0.001
+            );
+        }
+        var contentTokens = document.getFieldValue("ml.results.content_tokens", HashMap.class);
+        assertEquals(teResult2.getWeightedTokens().size(), contentTokens.entrySet().size());
+        if (teResult2.getWeightedTokens().isEmpty() == false) {
+            assertEquals(
+                (float) contentTokens.get(teResult2.getWeightedTokens().get(0).token()),
+                teResult2.getWeightedTokens().get(0).weight(),
+                0.001
+            );
+        }
+    }
 }