Browse Source

[LTR] Do not add fields extracted using a query to the FieldValueFeatureExtractor (#109437)

Aurélien FOUCRET 1 year ago
parent
commit
248b045d70

+ 12 - 20
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -47,7 +47,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.time.Instant;
-import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
@@ -236,7 +235,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         this.description = description;
         this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
         this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
-        this.input = ExceptionsHelper.requireNonNull(handleDefaultInput(input, modelType), INPUT);
+        this.input = ExceptionsHelper.requireNonNull(handleDefaultInput(input, inferenceConfig, modelType), INPUT);
         if (ExceptionsHelper.requireNonNull(modelSize, MODEL_SIZE_BYTES) < 0) {
             throw new IllegalArgumentException("[" + MODEL_SIZE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
         }
@@ -256,11 +255,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         this.prefixStrings = prefixStrings;
     }
 
-    private static TrainedModelInput handleDefaultInput(TrainedModelInput input, TrainedModelType modelType) {
-        if (modelType == null) {
-            return input;
-        }
-        return input == null ? modelType.getDefaultInput() : input;
+    private static TrainedModelInput handleDefaultInput(
+        TrainedModelInput input,
+        InferenceConfig inferenceConfig,
+        TrainedModelType modelType
+    ) {
+        return input == null && inferenceConfig != null ? inferenceConfig.getDefaultInput(modelType) : input;
     }
 
     public TrainedModelConfig(StreamInput in) throws IOException {
@@ -963,20 +963,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                     break;
                 }
             }
-            if (input != null && input.getFieldNames().isEmpty()) {
-                validationException = addValidationError("[input.field_names] must not be empty", validationException);
-            }
-            if (input != null
-                && input.getFieldNames()
-                    .stream()
-                    .filter(s -> s.contains("."))
-                    .flatMap(s -> Arrays.stream(Strings.delimitedListToStringArray(s, ".")))
-                    .anyMatch(String::isEmpty)) {
-                validationException = addValidationError(
-                    "[input.field_names] must only contain valid dot delimited field names",
-                    validationException
-                );
+
+            // Delegate input validation to the inference config.
+            if (inferenceConfig != null) {
+                validationException = inferenceConfig.validateTrainedModelInput(input, forCreation, validationException);
             }
+
             if (forCreation) {
                 validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
                 validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);

+ 42 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

@@ -8,12 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xpack.core.ml.MlConfigVersion;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
+import java.util.Arrays;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+
 public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {
 
     String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
@@ -65,6 +74,39 @@ public interface InferenceConfig extends NamedXContentObject, VersionedNamedWrit
         return false;
     }
 
+    @Nullable
+    default TrainedModelInput getDefaultInput(TrainedModelType modelType) {
+        if (modelType == null) {
+            return null;
+        }
+        return modelType.getDefaultInput();
+    }
+
+    default ActionRequestValidationException validateTrainedModelInput(
+        TrainedModelInput input,
+        boolean forCreation,
+        ActionRequestValidationException validationException
+    ) {
+
+        if (input != null && input.getFieldNames().isEmpty()) {
+            validationException = addValidationError("[input.field_names] must not be empty", validationException);
+        }
+
+        if (input != null
+            && input.getFieldNames()
+                .stream()
+                .filter(s -> s.contains("."))
+                .flatMap(s -> Arrays.stream(Strings.delimitedListToStringArray(s, ".")))
+                .anyMatch(String::isEmpty)) {
+            validationException = addValidationError(
+                "[input.field_names] must only contain valid dot delimited field names",
+                validationException
+            );
+        }
+
+        return validationException;
+    }
+
     default ElasticsearchStatusException incompatibleUpdateException(String updateName) {
         throw ExceptionsHelper.badRequestException(
             "Inference config of type [{}] can not be updated with a inference request of type [{}]",

+ 25 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -17,6 +18,8 @@ import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.MlConfigVersion;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearningToRankFeatureExtractorBuilder;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
@@ -30,6 +33,8 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+
 public class LearningToRankConfig extends RegressionConfig implements Rewriteable<LearningToRankConfig> {
 
     public static final ParseField NAME = new ParseField("learning_to_rank");
@@ -43,6 +48,8 @@ public class LearningToRankConfig extends RegressionConfig implements Rewriteabl
     private static final ObjectParser<LearningToRankConfig.Builder, Boolean> LENIENT_PARSER = createParser(true);
     private static final ObjectParser<LearningToRankConfig.Builder, Boolean> STRICT_PARSER = createParser(false);
 
+    private static final TrainedModelInput DEFAULT_INPUT = new TrainedModelInput(List.of());
+
     private static ObjectParser<LearningToRankConfig.Builder, Boolean> createParser(boolean lenient) {
         ObjectParser<LearningToRankConfig.Builder, Boolean> parser = new ObjectParser<>(
             NAME.getPreferredName(),
@@ -237,6 +244,24 @@ public class LearningToRankConfig extends RegressionConfig implements Rewriteabl
         return this;
     }
 
+    @Override
+    public TrainedModelInput getDefaultInput(TrainedModelType modelType) {
+        return DEFAULT_INPUT;
+    }
+
+    @Override
+    public ActionRequestValidationException validateTrainedModelInput(
+        TrainedModelInput input,
+        boolean forCreation,
+        ActionRequestValidationException validationException
+    ) {
+        if (forCreation && input != null && input.getFieldNames().isEmpty() == false) {
+            return addValidationError("cannot specify [input.field_names] for a model of type [learning_to_rank]", validationException);
+        }
+
+        return validationException;
+    }
+
     public static class Builder {
         private Integer numTopFeatureImportanceValues;
         private List<LearningToRankFeatureExtractorBuilder> learningToRankFeatureExtractorBuilders;

+ 24 - 12
x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java

@@ -31,10 +31,33 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase {
         putLearningToRankModel(MODEL_ID, """
             {
               "description": "super complex model for tests",
-              "input": { "field_names": ["cost", "product"] },
               "inference_config": {
                 "learning_to_rank": {
                   "feature_extractors": [
+                    {
+                      "query_extractor": {
+                        "feature_name": "cost",
+                        "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return doc['cost'].value;"}}}
+                      }
+                    },
+                    {
+                      "query_extractor": {
+                        "feature_name": "type_tv",
+                        "query": {"constant_score": {"filter": {"term": { "product": "TV" }}, "boost": 1.0}}
+                      }
+                    },
+                    {
+                      "query_extractor": {
+                        "feature_name": "type_vcr",
+                        "query": {"constant_score": {"filter": {"term": { "product": "VCR" }}, "boost": 1.0}}
+                      }
+                    },
+                    {
+                      "query_extractor": {
+                        "feature_name": "type_laptop",
+                        "query": {"constant_score": {"filter": {"term": { "product": "Laptop" }}, "boost": 1.0}}
+                      }
+                    },
                     {
                         "query_extractor": {
                             "feature_name": "two",
@@ -51,16 +74,6 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase {
                 }
               },
               "definition": {
-                "preprocessors" : [{
-                  "one_hot_encoding": {
-                    "field": "product",
-                    "hot_map": {
-                      "TV": "type_tv",
-                      "VCR": "type_vcr",
-                      "Laptop": "type_laptop"
-                    }
-                  }
-                }],
                 "trained_model": {
                   "ensemble": {
                     "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"],
@@ -351,7 +364,6 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase {
         deleteLearningToRankModel(MODEL_ID);
         putLearningToRankModel(MODEL_ID, """
             {
-              "input": { "field_names": ["cost"] },
               "inference_config": {
                 "learning_to_rank": {
                   "feature_extractors": [

+ 1 - 0
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -181,6 +181,7 @@ tasks.named("yamlRestTest").configure {
     'ml/inference_crud/Test put model model aliases with nlp model',
     'ml/inference_processor/Test create processor with missing mandatory fields',
     'ml/learning_to_rank_rescorer/Test rescore with missing model',
+    'ml/learning_to_rank_rescorer/Test model input validation',
     'ml/inference_stats_crud/Test get stats given missing trained model',
     'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
     'ml/jobs_crud/Test cannot create job with model snapshot id set',

+ 135 - 120
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java

@@ -31,151 +31,166 @@ public class LearningToRankRescorerIT extends InferenceTestCase {
         putRegressionModel(MODEL_ID, """
             {
               "description": "super complex model for tests",
-              "input": {"field_names": ["cost", "product"]},
               "inference_config": {
                 "learning_to_rank": {
                   "feature_extractors": [
                     {
                       "query_extractor": {
-                        "feature_name": "two",
-                        "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}}
+                        "feature_name": "cost",
+                        "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return doc['cost'].value;"}}}
                       }
                     },
                     {
                       "query_extractor": {
-                        "feature_name": "product_bm25",
-                        "query": {"term": {"product": "{{keyword}}"}}
+                        "feature_name": "type_tv",
+                        "query": {"constant_score": {"filter": {"term": { "product": "TV" }}, "boost": 1.0}}
                       }
+                    },
+                    {
+                      "query_extractor": {
+                        "feature_name": "type_vcr",
+                        "query": {"constant_score": {"filter": {"term": { "product": "VCR" }}, "boost": 1.0}}
+                      }
+                    },
+                    {
+                      "query_extractor": {
+                        "feature_name": "type_laptop",
+                        "query": {"constant_score": {"filter": {"term": { "product": "Laptop" }}, "boost": 1.0}}
+                      }
+                    },
+                    {
+                        "query_extractor": {
+                            "feature_name": "two",
+                            "query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 2.0;" } } }
+                        }
+                    },
+                    {
+                        "query_extractor": {
+                            "feature_name": "product_bm25",
+                            "query": { "term": { "product": "{{keyword}}" } }
+                        }
                     }
                   ]
                 }
               },
               "definition": {
-                "preprocessors" : [{
-                  "one_hot_encoding": {
-                    "field": "product",
-                    "hot_map": {
-                      "TV": "type_tv",
-                      "VCR": "type_vcr",
-                      "Laptop": "type_laptop"
-                    }
-                  }
-                }],
                 "trained_model": {
                   "ensemble": {
                     "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"],
                     "target_type": "regression",
                     "trained_models": [
-                      {
-                        "tree": {
-                          "feature_names": ["cost"],
-                          "tree_structure": [
-                            {
-                              "node_index": 0,
-                              "split_feature": 0,
-                              "split_gain": 12,
-                              "threshold": 400,
-                              "decision_type": "lte",
-                              "default_left": true,
-                              "left_child": 1,
-                              "right_child": 2
-                            },
-                            {
-                              "node_index": 1,
-                              "leaf_value": 5.0
-                            },
-                            {
-                              "node_index": 2,
-                              "leaf_value": 2.0
-                            }
-                          ],
-                          "target_type": "regression"
+                    {
+                      "tree": {
+                        "feature_names": [
+                          "cost"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 400,
+                          "decision_type": "lte",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 5.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 2.0
                         }
-                      },
-                      {
-                        "tree": {
-                          "feature_names": [
-                            "type_tv"
-                          ],
-                          "tree_structure": [
-                            {
-                              "node_index": 0,
-                              "split_feature": 0,
-                              "split_gain": 12,
-                              "threshold": 1,
-                              "decision_type": "lt",
-                              "default_left": true,
-                              "left_child": 1,
-                              "right_child": 2
-                            },
-                            {
-                              "node_index": 1,
-                              "leaf_value": 1.0
-                            },
-                            {
-                              "node_index": 2,
-                              "leaf_value": 12.0
-                            }
-                          ],
-                          "target_type": "regression"
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                    {
+                      "tree": {
+                        "feature_names": [
+                          "type_tv"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 12.0
                         }
-                      },
-                      {
-                        "tree": {
-                          "feature_names": [
-                            "two"
-                          ],
-                          "tree_structure": [
-                            {
-                              "node_index": 0,
-                              "split_feature": 0,
-                              "split_gain": 12,
-                              "threshold": 1,
-                              "decision_type": "lt",
-                              "default_left": true,
-                              "left_child": 1,
-                              "right_child": 2
-                            },
-                            {
-                              "node_index": 1,
-                              "leaf_value": 1.0
-                            },
-                            {
-                              "node_index": 2,
-                              "leaf_value": 2.0
-                            }
-                          ],
-                          "target_type": "regression"
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                     {
+                      "tree": {
+                        "feature_names": [
+                          "two"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 2.0
                         }
-                      },
-                      {
-                        "tree": {
-                          "feature_names": [
-                            "product_bm25"
-                          ],
-                          "tree_structure": [
-                            {
-                              "node_index": 0,
-                              "split_feature": 0,
-                              "split_gain": 12,
-                              "threshold": 1,
-                              "decision_type": "lt",
-                              "default_left": true,
-                              "left_child": 1,
-                              "right_child": 2
-                            },
-                            {
-                              "node_index": 1,
-                              "leaf_value": 1.0
-                            },
-                            {
-                              "node_index": 2,
-                              "leaf_value": 4.0
-                            }
-                          ],
-                          "target_type": "regression"
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                     {
+                      "tree": {
+                        "feature_names": [
+                          "product_bm25"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 4.0
                         }
+                        ],
+                        "target_type": "regression"
                       }
+                    }
                     ]
                   }
                 }

+ 12 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java

@@ -24,6 +24,8 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
+import static java.util.function.Predicate.not;
+
 public class LearningToRankRescorerContext extends RescoreContext {
 
     final SearchExecutionContext executionContext;
@@ -52,12 +54,9 @@ public class LearningToRankRescorerContext extends RescoreContext {
 
     List<FeatureExtractor> buildFeatureExtractors(IndexSearcher searcher) throws IOException {
         assert this.regressionModelDefinition != null && this.learningToRankConfig != null;
+
         List<FeatureExtractor> featureExtractors = new ArrayList<>();
-        if (this.regressionModelDefinition.inputFields().isEmpty() == false) {
-            featureExtractors.add(
-                new FieldValueFeatureExtractor(new ArrayList<>(this.regressionModelDefinition.inputFields()), this.executionContext)
-            );
-        }
+
         List<Weight> weights = new ArrayList<>();
         List<String> queryFeatureNames = new ArrayList<>();
         for (LearningToRankFeatureExtractorBuilder featureExtractorBuilder : learningToRankConfig.getFeatureExtractorBuilders()) {
@@ -72,6 +71,14 @@ public class LearningToRankRescorerContext extends RescoreContext {
             featureExtractors.add(new QueryFeatureExtractor(queryFeatureNames, weights));
         }
 
+        List<String> fieldValueExtractorFields = this.regressionModelDefinition.inputFields()
+            .stream()
+            .filter(not(queryFeatureNames::contains))
+            .toList();
+        if (fieldValueExtractorFields.isEmpty() == false) {
+            featureExtractors.add(new FieldValueFeatureExtractor(fieldValueExtractorFields, this.executionContext));
+        }
+
         return featureExtractors;
     }
 

+ 43 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java

@@ -41,6 +41,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.in;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
@@ -193,8 +194,7 @@ public class LearningToRankRescorerBuilderRewriteTests extends AbstractBuilderTe
 
     public void testBuildContext() throws Exception {
         LocalModel localModel = mock(LocalModel.class);
-        List<String> inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME);
-        when(localModel.inputFields()).thenReturn(inputFields);
+        when(localModel.inputFields()).thenReturn(GOOD_MODEL_CONFIG.getInput().getFieldNames());
 
         IndexSearcher searcher = mock(IndexSearcher.class);
         doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class));
@@ -211,11 +211,48 @@ public class LearningToRankRescorerBuilderRewriteTests extends AbstractBuilderTe
         assertNotNull(rescoreContext);
         assertThat(rescoreContext.getWindowSize(), equalTo(20));
         List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
-        assertThat(featureExtractors, hasSize(2));
-        assertThat(
-            featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(),
-            containsInAnyOrder("feature_1", "feature_2", DOUBLE_FIELD_NAME, INT_FIELD_NAME)
+        assertThat(featureExtractors, hasSize(1));
+
+        FeatureExtractor queryExtractor = featureExtractors.get(0);
+        assertThat(queryExtractor, instanceOf(QueryFeatureExtractor.class));
+        assertThat(queryExtractor.featureNames(), hasSize(2));
+        assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2"));
+    }
+
+    public void testLegacyFieldValueExtractorBuildContext() throws Exception {
+        // Models created before 8.15 were saved with input fields.
+        // Check that a field value extractor is still created for them, without the features already provided by query extractors.
+        LocalModel localModel = mock(LocalModel.class);
+        when(localModel.inputFields()).thenReturn(List.of("feature_1", "field_1", "field_2"));
+
+        IndexSearcher searcher = mock(IndexSearcher.class);
+        doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class));
+        SearchExecutionContext context = createSearchExecutionContext(searcher);
+
+        LearningToRankRescorerBuilder rescorerBuilder = new LearningToRankRescorerBuilder(
+            localModel,
+            (LearningToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(),
+            null,
+            mock(LearningToRankService.class)
         );
+
+        LearningToRankRescorerContext rescoreContext = rescorerBuilder.innerBuildContext(20, context);
+        assertNotNull(rescoreContext);
+        assertThat(rescoreContext.getWindowSize(), equalTo(20));
+        List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
+
+        assertThat(featureExtractors, hasSize(2));
+
+        FeatureExtractor queryExtractor = featureExtractors.stream().filter(fe -> fe instanceof QueryFeatureExtractor).findFirst().get();
+        assertThat(queryExtractor.featureNames(), hasSize(2));
+        assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2"));
+
+        FeatureExtractor fieldValueExtractor = featureExtractors.stream()
+            .filter(fe -> fe instanceof FieldValueFeatureExtractor)
+            .findFirst()
+            .get();
+        assertThat(fieldValueExtractor.featureNames(), hasSize(2));
+        assertThat(fieldValueExtractor.featureNames(), containsInAnyOrder("field_1", "field_2"));
     }
 
     private LearningToRankRescorerBuilder rewriteAndFetch(

+ 2 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java

@@ -50,11 +50,10 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
 public class LearningToRankServiceTests extends ESTestCase {
-    public static final String GOOD_MODEL = "inferenceEntityId";
-    public static final String BAD_MODEL = "badModel";
+    public static final String GOOD_MODEL = "inference-entity-id";
+    public static final String BAD_MODEL = "bad-model";
     public static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder()
         .setModelId(GOOD_MODEL)
-        .setInput(new TrainedModelInput(List.of("field1", "field2")))
         .setEstimatedOperations(1)
         .setModelSize(2)
         .setModelType(TrainedModelType.TREE_ENSEMBLE)

+ 88 - 11
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learning_to_rank_rescorer.yml

@@ -9,22 +9,37 @@ setup:
         body: >
           {
             "description": "super complex model for tests",
-            "input": {"field_names": ["cost", "product"]},
             "inference_config": {
               "learning_to_rank": {
+                "feature_extractors": [
+                  {
+                    "query_extractor": {
+                      "feature_name": "cost",
+                      "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return doc['cost'].value;"}}}
+                    }
+                  },
+                  {
+                    "query_extractor": {
+                      "feature_name": "type_tv",
+                      "query": {"term": {"product":  "TV"}}
+                    }
+                  },
+                  {
+                    "query_extractor": {
+                      "feature_name": "type_vcr",
+                      "query": {"term": {"product":  "VCR"}}
+                    }
+                  },
+                  {
+                    "query_extractor": {
+                      "feature_name": "type_laptop",
+                      "query": {"term": {"product":  "Laptop"}}
+                    }
+                  }
+                ]
               }
             },
             "definition": {
-              "preprocessors" : [{
-                "one_hot_encoding": {
-                  "field": "product",
-                  "hot_map": {
-                    "TV": "type_tv",
-                    "VCR": "type_vcr",
-                    "Laptop": "type_laptop"
-                  }
-                }
-              }],
               "trained_model": {
                 "ensemble": {
                   "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop"],
@@ -246,3 +261,65 @@ setup:
             }
           }
   - length: { hits.hits: 0 }
+---
+"Test model input validation":
+  - skip:
+      features: headers
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      catch: bad_request
+      ml.put_trained_model:
+        model_id: bad-model
+        body: >
+          {
+            "description": "a bad model",
+            "input": {
+              "field_names": ["cost"]
+            },
+            "inference_config": {
+              "learning_to_rank": { }
+            },
+            "definition": {
+              "trained_model": {
+                "ensemble": {
+                  "feature_names": ["cost"],
+                  "target_type": "regression",
+                  "trained_models": [
+                    {
+                      "tree": {
+                        "feature_names": [
+                          "cost"
+                        ],
+                        "tree_structure": [
+                          {
+                            "node_index": 0,
+                            "split_feature": 0,
+                            "split_gain": 12,
+                            "threshold": 400,
+                            "decision_type": "lte",
+                            "default_left": true,
+                            "left_child": 1,
+                            "right_child": 2
+                          },
+                          {
+                            "node_index": 1,
+                            "leaf_value": 5.0
+                          },
+                          {
+                            "node_index": 2,
+                            "leaf_value": 2.0
+                          }
+                        ],
+                        "target_type": "regression"
+                      }
+                    }
+                  ]
+                }
+              }
+            }
+          }
+
+  - match: { status: 400 }
+  - match: { error.root_cause.0.type: "action_request_validation_exception" }
+  - match: { error.root_cause.0.reason: "Validation Failed: 1: cannot specify [input.field_names] for a model of type [learning_to_rank];" }