Browse Source

[ML] add ltr feature extractor for queries (#97331)

This commit builds on our suite of LTR feature extractors. Now, instead
of just runtime-fields and document fields, query interaction features
are supported.

A user can store a query (for features that don't need `_search` time
information) or supply one via the `inference_config` object in the
`inference_rescorer` object. 

An example stored configuration:

```
{
//<snip>
"input": {"field_names": ["cost", "product"]},
"inference_config": {
  "learn_to_rank": {
    "feature_extractors": [{"query_extractor": {"feature_name": "two", "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}}}}]
  }
}
//</snip>
}
```

The above will provide the document/runtime fields `cost` and `product`
to the model at inference time, along with an extracted feature called
`"two"`.

However the more general usage would be features required at search
time.

```
POST _search
{
  "query": {"match": {"field": {"query": "quick brown fox"}}},
  "rescorer": {
    "window_size": 10,
    "inference": {
      "model_id": "ltr_model",
      "inference_config": {
        "learn_to_rank": {"feature_extractors":[{"query_extractor": {"feature_name": "field_bm25", "query": {"match": {"field": {"query": "quick brown fox"}}}}]}
      }
    }
  }
}
```

All queries are grabbed as early as possible within the search request.
This way the appropriate rewrites can occur. Additionally, the parsed
queries from the rescorer are provided via DFS, so that term-stats can
be gathered without any additional configuration by the user. This means
that terms only used via a `feature_extractor` can have accurate term
statistics when using DFS.
Benjamin Trent 2 years ago
parent
commit
5625ebe74a
22 changed files with 1525 additions and 258 deletions
  1. 1 0
      x-pack/plugin/core/src/main/java/module-info.java
  2. 18 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java
  3. 22 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java
  4. 21 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java
  5. 6 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java
  6. 11 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/LearnToRankFeatureExtractorBuilder.java
  7. 111 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/QueryExtractorBuilder.java
  8. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  9. 16 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java
  10. 13 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigTests.java
  11. 22 18
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java
  12. 81 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/QueryExtractorBuilderTests.java
  13. 181 87
      x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java
  14. 169 89
      x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java
  15. 19 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
  16. 13 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java
  17. 295 34
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java
  18. 43 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java
  19. 78 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/QueryFeatureExtractor.java
  20. 175 10
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java
  21. 101 7
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java
  22. 127 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/QueryFeatureExtractorTests.java

+ 1 - 0
x-pack/plugin/core/src/main/java/module-info.java

@@ -98,6 +98,7 @@ module org.elasticsearch.xcore {
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;
+    exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr;
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
     exports org.elasticsearch.xpack.core.ml.inference.trainedmodel;

+ 18 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java

@@ -15,6 +15,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -52,6 +54,14 @@ public class MlLTRNamedXContentProvider implements NamedXContentProvider {
                 LearnToRankConfigUpdate::fromXContentStrict
             )
         );
+        // LTR extractors
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                QueryExtractorBuilder.NAME,
+                QueryExtractorBuilder::fromXContent
+            )
+        );
         return namedXContent;
     }
 
@@ -69,6 +79,14 @@ public class MlLTRNamedXContentProvider implements NamedXContentProvider {
                 LearnToRankConfigUpdate::new
             )
         );
+        // LTR Extractors
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                QueryExtractorBuilder.NAME.getPreferredName(),
+                QueryExtractorBuilder::new
+            )
+        );
         return namedWriteables;
     }
 }

+ 22 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java

@@ -10,6 +10,8 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -18,12 +20,13 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFea
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 
-public class LearnToRankConfig extends RegressionConfig {
+public class LearnToRankConfig extends RegressionConfig implements Rewriteable<LearnToRankConfig> {
 
     public static final ParseField NAME = new ParseField("learn_to_rank");
     static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersion.current();
@@ -171,6 +174,24 @@ public class LearnToRankConfig extends RegressionConfig {
         return MIN_SUPPORTED_TRANSPORT_VERSION;
     }
 
+    @Override
+    public LearnToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
+        if (this.featureExtractorBuilders.isEmpty()) {
+            return this;
+        }
+        boolean rewritten = false;
+        List<LearnToRankFeatureExtractorBuilder> rewrittenExtractors = new ArrayList<>(this.featureExtractorBuilders.size());
+        for (LearnToRankFeatureExtractorBuilder extractorBuilder : this.featureExtractorBuilders) {
+            LearnToRankFeatureExtractorBuilder rewrittenExtractor = Rewriteable.rewrite(extractorBuilder, ctx);
+            rewrittenExtractors.add(rewrittenExtractor);
+            rewritten |= (rewrittenExtractor != extractorBuilder);
+        }
+        if (rewritten) {
+            return new LearnToRankConfig(getNumTopFeatureImportanceValues(), rewrittenExtractors);
+        }
+        return this;
+    }
+
     public static class Builder {
         private Integer numTopFeatureImportanceValues;
         private List<LearnToRankFeatureExtractorBuilder> learnToRankFeatureExtractorBuilders;

+ 21 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java

@@ -9,6 +9,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -30,7 +32,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceCo
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.FEATURE_EXTRACTORS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
 
-public class LearnToRankConfigUpdate implements InferenceConfigUpdate, NamedXContentObject {
+public class LearnToRankConfigUpdate implements InferenceConfigUpdate, NamedXContentObject, Rewriteable<LearnToRankConfigUpdate> {
 
     public static final ParseField NAME = LearnToRankConfig.NAME;
 
@@ -200,6 +202,24 @@ public class LearnToRankConfigUpdate implements InferenceConfigUpdate, NamedXCon
                 || Objects.equals(originalConfig.getFeatureExtractorBuilders(), featureExtractorBuilderList));
     }
 
+    @Override
+    public LearnToRankConfigUpdate rewrite(QueryRewriteContext ctx) throws IOException {
+        if (featureExtractorBuilderList.isEmpty()) {
+            return this;
+        }
+        List<LearnToRankFeatureExtractorBuilder> rewrittenBuilders = new ArrayList<>(featureExtractorBuilderList.size());
+        boolean rewritten = false;
+        for (LearnToRankFeatureExtractorBuilder extractorBuilder : featureExtractorBuilderList) {
+            LearnToRankFeatureExtractorBuilder rewrittenExtractor = Rewriteable.rewrite(extractorBuilder, ctx);
+            rewritten |= (rewrittenExtractor != extractorBuilder);
+            rewrittenBuilders.add(rewrittenExtractor);
+        }
+        if (rewritten) {
+            return new LearnToRankConfigUpdate(getNumTopFeatureImportanceValues(), rewrittenBuilders);
+        }
+        return this;
+    }
+
     public static class Builder implements InferenceConfigUpdate.Builder<Builder, LearnToRankConfigUpdate> {
         private Integer numTopFeatureImportanceValues;
         private List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilderList;

+ 6 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

@@ -145,7 +145,12 @@ public class EnsembleInferenceModel implements InferenceModel {
         if (preparedForInference == false) {
             throw ExceptionsHelper.serverError("model is not prepared for inference");
         }
-        LOGGER.debug(() -> "Inference called with feature names [" + Strings.arrayToCommaDelimitedString(featureNames) + "]");
+        LOGGER.debug(
+            () -> "Inference called with feature names ["
+                + Strings.arrayToCommaDelimitedString(featureNames)
+                + "] values "
+                + Arrays.toString(features)
+        );
         double[][] inferenceResults = new double[this.models.size()][];
         double[][] featureInfluence = new double[features.length][];
         int i = 0;

+ 11 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/LearnToRankFeatureExtractorBuilder.java

@@ -8,10 +8,15 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
-public interface LearnToRankFeatureExtractorBuilder extends NamedXContentObject, NamedWriteable {
+public interface LearnToRankFeatureExtractorBuilder
+    extends
+        NamedXContentObject,
+        NamedWriteable,
+        Rewriteable<LearnToRankFeatureExtractorBuilder> {
 
     ParseField FEATURE_NAME = new ParseField("feature_name");
 
@@ -19,4 +24,9 @@ public interface LearnToRankFeatureExtractorBuilder extends NamedXContentObject,
      * @return The input feature that this extractor satisfies
      */
     String featureName();
+
+    /**
+     * @throws Exception If the extractor is invalid.
+     */
+    void validate() throws Exception;
 }

+ 111 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/QueryExtractorBuilder.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
+
+import java.io.IOException;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_CONFIG_QUERY_BAD_FORMAT;
+import static org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper.requireNonNull;
+
+public record QueryExtractorBuilder(String featureName, QueryProvider query) implements LearnToRankFeatureExtractorBuilder {
+
+    public static final ParseField NAME = new ParseField("query_extractor");
+    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
+    public static final ParseField QUERY = new ParseField("query");
+
+    private static final ConstructingObjectParser<QueryExtractorBuilder, Void> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(),
+        a -> new QueryExtractorBuilder((String) a[0], (QueryProvider) a[1])
+    );
+    private static final ConstructingObjectParser<QueryExtractorBuilder, Void> LENIENT_PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(),
+        true,
+        a -> new QueryExtractorBuilder((String) a[0], (QueryProvider) a[1])
+    );
+    static {
+        PARSER.declareString(constructorArg(), FEATURE_NAME);
+        PARSER.declareObject(constructorArg(), (p, c) -> QueryProvider.fromXContent(p, false, INFERENCE_CONFIG_QUERY_BAD_FORMAT), QUERY);
+        LENIENT_PARSER.declareString(constructorArg(), FEATURE_NAME);
+        LENIENT_PARSER.declareObject(
+            constructorArg(),
+            (p, c) -> QueryProvider.fromXContent(p, true, INFERENCE_CONFIG_QUERY_BAD_FORMAT),
+            QUERY
+        );
+    }
+
+    public static QueryExtractorBuilder fromXContent(XContentParser parser, Object context) {
+        boolean lenient = Boolean.TRUE.equals(context);
+        return lenient ? LENIENT_PARSER.apply(parser, null) : PARSER.apply(parser, null);
+    }
+
+    public QueryExtractorBuilder(String featureName, QueryProvider query) {
+        this.featureName = requireNonNull(featureName, FEATURE_NAME);
+        this.query = requireNonNull(query, QUERY);
+    }
+
+    public QueryExtractorBuilder(StreamInput input) throws IOException {
+        this(input.readString(), QueryProvider.fromStream(input));
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAME.getPreferredName(), featureName);
+        builder.field(QUERY.getPreferredName(), query.getQuery());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(featureName);
+        query.writeTo(out);
+    }
+
+    @Override
+    public String featureName() {
+        return featureName;
+    }
+
+    @Override
+    public void validate() throws Exception {
+        if (query.getParsingException() != null) {
+            throw query.getParsingException();
+        }
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public QueryExtractorBuilder rewrite(QueryRewriteContext ctx) throws IOException {
+        QueryProvider rewritten = Rewriteable.rewrite(query, ctx);
+        if (rewritten == query) {
+            return this;
+        }
+        return new QueryExtractorBuilder(featureName, rewritten);
+    }
+}

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -114,6 +114,8 @@ public final class Messages {
     public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
+    public static final String INFERENCE_CONFIG_QUERY_BAD_FORMAT = "Inference config query is not parsable";
+    public static final String INFERENCE_CONFIG_INCORRECT_TYPE = "Inference config of type [{0}] is invalid, must be of type [{1}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
     public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
     public static final String VOCABULARY_NOT_FOUND = "Could not find vocabulary document [{1}] for trained model [{0}]";

+ 16 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java

@@ -14,6 +14,8 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -25,7 +27,7 @@ import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Objects;
 
-public class QueryProvider implements Writeable, ToXContentObject {
+public class QueryProvider implements Writeable, ToXContentObject, Rewriteable<QueryProvider> {
 
     private static final Logger logger = LogManager.getLogger(QueryProvider.class);
 
@@ -131,4 +133,17 @@ public class QueryProvider implements Writeable, ToXContentObject {
         builder.map(query);
         return builder;
     }
+
+    @Override
+    public QueryProvider rewrite(QueryRewriteContext ctx) throws IOException {
+        assert parsedQuery != null;
+        if (parsedQuery == null) {
+            return this;
+        }
+        QueryBuilder rewritten = Rewriteable.rewrite(parsedQuery, ctx);
+        if (rewritten == parsedQuery) {
+            return this;
+        }
+        return new QueryProvider(query, rewritten, parsingException);
+    }
 }

+ 13 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -22,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -43,7 +45,7 @@ public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToR
             randomBoolean() ? null : randomIntBetween(0, 10),
             randomBoolean()
                 ? null
-                : Stream.generate(() -> new TestValueExtractor(randomAlphaOfLength(10))).limit(randomInt(5)).collect(Collectors.toList())
+                : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList())
         );
     }
 
@@ -117,6 +119,7 @@ public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToR
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
         namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
@@ -128,7 +131,7 @@ public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToR
         return new NamedWriteableRegistry(namedWriteables);
     }
 
-    static class TestValueExtractor implements LearnToRankFeatureExtractorBuilder {
+    private static class TestValueExtractor implements LearnToRankFeatureExtractorBuilder {
         public static final ParseField NAME = new ParseField("test");
         private final String featureName;
 
@@ -182,6 +185,9 @@ public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToR
             return featureName;
         }
 
+        @Override
+        public void validate() throws Exception {}
+
         @Override
         public String getName() {
             return NAME.getPreferredName();
@@ -199,5 +205,10 @@ public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToR
         public int hashCode() {
             return Objects.hash(featureName);
         }
+
+        @Override
+        public TestValueExtractor rewrite(QueryRewriteContext ctx) throws IOException {
+            return this;
+        }
     }
 }

+ 22 - 18
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java

@@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
@@ -17,11 +18,16 @@ import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig;
 import static org.hamcrest.Matchers.equalTo;
@@ -31,10 +37,15 @@ import static org.hamcrest.Matchers.is;
 public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCase<LearnToRankConfigUpdate> {
 
     public static LearnToRankConfigUpdate randomLearnToRankConfigUpdate() {
-        return new LearnToRankConfigUpdate(randomBoolean() ? null : randomIntBetween(0, 10), null);
+        return new LearnToRankConfigUpdate(
+            randomBoolean() ? null : randomIntBetween(0, 10),
+            randomBoolean()
+                ? null
+                : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList())
+        );
     }
 
-    public void testApply() {
+    public void testApply() throws IOException {
         LearnToRankConfig originalConfig = randomLearnToRankConfig();
         assertThat(originalConfig, equalTo(LearnToRankConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
         assertThat(
@@ -46,8 +57,14 @@ public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCa
             equalTo(new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1).build().apply(originalConfig))
         );
 
-        LearnToRankFeatureExtractorBuilder extractorBuilder = new LearnToRankConfigTests.TestValueExtractor("foo");
-        LearnToRankFeatureExtractorBuilder extractorBuilder2 = new LearnToRankConfigTests.TestValueExtractor("bar");
+        LearnToRankFeatureExtractorBuilder extractorBuilder = new QueryExtractorBuilder(
+            "foo",
+            QueryProvider.fromParsedQuery(QueryBuilders.termQuery("foo", "bar"))
+        );
+        LearnToRankFeatureExtractorBuilder extractorBuilder2 = new QueryExtractorBuilder(
+            "bar",
+            QueryProvider.fromParsedQuery(QueryBuilders.termQuery("foo", "bar"))
+        );
 
         LearnToRankConfig config = new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1)
             .setFeatureExtractorBuilders(List.of(extractorBuilder2, extractorBuilder))
@@ -89,13 +106,6 @@ public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCa
         namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
-        namedXContent.add(
-            new NamedXContentRegistry.Entry(
-                LearnToRankFeatureExtractorBuilder.class,
-                LearnToRankConfigTests.TestValueExtractor.NAME,
-                LearnToRankConfigTests.TestValueExtractor::fromXContent
-            )
-        );
         return new NamedXContentRegistry(namedXContent);
     }
 
@@ -103,13 +113,7 @@ public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCa
     protected NamedWriteableRegistry writableRegistry() {
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
-        namedWriteables.add(
-            new NamedWriteableRegistry.Entry(
-                LearnToRankFeatureExtractorBuilder.class,
-                LearnToRankConfigTests.TestValueExtractor.NAME.getPreferredName(),
-                LearnToRankConfigTests.TestValueExtractor::new
-            )
-        );
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }
 

+ 81 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/QueryExtractorBuilderTests.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractXContentSerializingTestCase;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.QueryProviderTests;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.function.Predicate;
+
+public class QueryExtractorBuilderTests extends AbstractXContentSerializingTestCase<QueryExtractorBuilder> {
+
+    protected boolean lenient;
+
+    public static QueryExtractorBuilder randomInstance() {
+        return new QueryExtractorBuilder(randomAlphaOfLength(10), QueryProviderTests.createRandomValidQueryProvider());
+    }
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.isEmpty() == false;
+    }
+
+    @Override
+    protected Writeable.Reader<QueryExtractorBuilder> instanceReader() {
+        return QueryExtractorBuilder::new;
+    }
+
+    @Override
+    protected QueryExtractorBuilder createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected QueryExtractorBuilder mutateInstance(QueryExtractorBuilder instance) throws IOException {
+        int i = randomInt(1);
+        return switch (i) {
+            case 0 -> new QueryExtractorBuilder(randomAlphaOfLength(10), instance.query());
+            case 1 -> new QueryExtractorBuilder(instance.featureName(), QueryProviderTests.createRandomValidQueryProvider());
+            default -> throw new AssertionError("unknown random case for instance mutation");
+        };
+    }
+
+    @Override
+    protected QueryExtractorBuilder doParseInstance(XContentParser parser) throws IOException {
+        return QueryExtractorBuilder.fromXContent(parser, lenient);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
+    }
+}

+ 181 - 87
x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java

@@ -29,90 +29,152 @@ public class MlRescorerIT extends ESRestTestCase {
     @Before
     public void setupModelAndData() throws IOException {
         putRegressionModel(MODEL_ID, """
-            {
-                        "description": "super complex model for tests",
-                        "input": {"field_names": ["cost", "product"]},
-                        "inference_config": {
-                          "learn_to_rank": {
-                          }
+             {
+              "description": "super complex model for tests",
+              "input": {"field_names": ["cost", "product"]},
+              "inference_config": {
+                "learn_to_rank": {
+                  "feature_extractors": [{
+                    "query_extractor": {
+                      "feature_name": "two",
+                      "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}}
+                    }
+                  }]
+                }
+              },
+              "definition": {
+                "preprocessors" : [{
+                  "one_hot_encoding": {
+                    "field": "product",
+                    "hot_map": {
+                      "TV": "type_tv",
+                      "VCR": "type_vcr",
+                      "Laptop": "type_laptop"
+                    }
+                  }
+                }],
+                "trained_model": {
+                  "ensemble": {
+                    "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"],
+                    "target_type": "regression",
+                    "trained_models": [
+                    {
+                      "tree": {
+                        "feature_names": [
+                          "cost"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 400,
+                          "decision_type": "lte",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
                         },
-                        "definition": {
-                          "preprocessors" : [{
-                            "one_hot_encoding": {
-                              "field": "product",
-                              "hot_map": {
-                                "TV": "type_tv",
-                                "VCR": "type_vcr",
-                                "Laptop": "type_laptop"
-                              }
-                            }
-                          }],
-                          "trained_model": {
-                            "ensemble": {
-                              "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop"],
-                              "target_type": "regression",
-                              "trained_models": [
-                              {
-                                "tree": {
-                                  "feature_names": [
-                                    "cost"
-                                  ],
-                                  "tree_structure": [
-                                  {
-                                    "node_index": 0,
-                                    "split_feature": 0,
-                                    "split_gain": 12,
-                                    "threshold": 400,
-                                    "decision_type": "lte",
-                                    "default_left": true,
-                                    "left_child": 1,
-                                    "right_child": 2
-                                  },
-                                  {
-                                    "node_index": 1,
-                                    "leaf_value": 5.0
-                                  },
-                                  {
-                                    "node_index": 2,
-                                    "leaf_value": 2.0
-                                  }
-                                  ],
-                                  "target_type": "regression"
-                                }
-                              },
-                              {
-                                "tree": {
-                                  "feature_names": [
-                                    "type_tv"
-                                  ],
-                                  "tree_structure": [
-                                  {
-                                    "node_index": 0,
-                                    "split_feature": 0,
-                                    "split_gain": 12,
-                                    "threshold": 1,
-                                    "decision_type": "lt",
-                                    "default_left": true,
-                                    "left_child": 1,
-                                    "right_child": 2
-                                  },
-                                  {
-                                    "node_index": 1,
-                                    "leaf_value": 1.0
-                                  },
-                                  {
-                                    "node_index": 2,
-                                    "leaf_value": 12.0
-                                  }
-                                  ],
-                                  "target_type": "regression"
-                                }
-                              }
-                              ]
-                            }
-                          }
+                        {
+                          "node_index": 1,
+                          "leaf_value": 5.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 2.0
                         }
-                      }""");
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                    {
+                      "tree": {
+                        "feature_names": [
+                          "type_tv"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 12.0
+                        }
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                     {
+                      "tree": {
+                        "feature_names": [
+                          "two"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 2.0
+                        }
+                        ],
+                        "target_type": "regression"
+                      }
+                    },
+                     {
+                      "tree": {
+                        "feature_names": [
+                          "product_bm25"
+                        ],
+                        "tree_structure": [
+                        {
+                          "node_index": 0,
+                          "split_feature": 0,
+                          "split_gain": 12,
+                          "threshold": 1,
+                          "decision_type": "lt",
+                          "default_left": true,
+                          "left_child": 1,
+                          "right_child": 2
+                        },
+                        {
+                          "node_index": 1,
+                          "leaf_value": 1.0
+                        },
+                        {
+                          "node_index": 2,
+                          "leaf_value": 4.0
+                        }
+                        ],
+                        "target_type": "regression"
+                      }
+                    }
+                    ]
+                  }
+                }
+              }
+            }""");
         createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", randomIntBetween(1, 3)).build(), """
             "properties":{
              "product":{"type": "keyword"},
@@ -142,7 +204,7 @@ public class MlRescorerIT extends ESRestTestCase {
             }""");
 
         Map<String, Object> response = responseAsMap(searchResponse);
-        assertThat((List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
+        assertThat((List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
     }
 
     @SuppressWarnings("unchecked")
@@ -155,14 +217,46 @@ public class MlRescorerIT extends ESRestTestCase {
             "rescore": {
                     "window_size": 10,
                     "inference": {
-                        "model_id": "basic-ltr-model"
+                        "model_id": "basic-ltr-model",
+                        "inference_config": {
+                          "learn_to_rank": {
+                            "feature_extractors":[
+                              {"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}}
+                            ]
+                          }
                         }
+                      }
                 }
 
             }""");
 
         Map<String, Object> response = responseAsMap(searchResponse);
-        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
+        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
+
+        searchResponse = searchDfs("""
+            {
+            "rescore": {
+                    "window_size": 10,
+                    "inference": {
+                        "model_id": "basic-ltr-model",
+                        "inference_config": {
+                          "learn_to_rank": {
+                            "feature_extractors":[
+                              {"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}}
+                              ]
+                            }
+                          }
+                        }
+                }
+
+            }""");
+
+        response = responseAsMap(searchResponse);
+        assertThat(
+            response.toString(),
+            (List<Double>) XContentMapValues.extractValue("hits.hits._score", response),
+            contains(20.0, 20.0, 9.0, 9.0, 6.0)
+        );
     }
 
     @SuppressWarnings("unchecked")
@@ -219,7 +313,7 @@ public class MlRescorerIT extends ESRestTestCase {
             }""", false);
 
         Map<String, Object> response = responseAsMap(searchResponse);
-        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
+        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
 
         searchResponse = searchCanMatch("""
             { "query": {
@@ -235,7 +329,7 @@ public class MlRescorerIT extends ESRestTestCase {
             }""", true);
 
         response = responseAsMap(searchResponse);
-        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
+        assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
     }
 
     private void indexData(String data) throws IOException {

+ 169 - 89
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java

@@ -27,89 +27,151 @@ public class InferenceRescorerIT extends InferenceTestCase {
     public void setupModelAndData() throws IOException {
         putRegressionModel(MODEL_ID, """
             {
-                        "description": "super complex model for tests",
-                        "input": {"field_names": ["cost", "product"]},
-                        "inference_config": {
-                          "learn_to_rank": {
-                          }
-                        },
-                        "definition": {
-                          "preprocessors" : [{
-                            "one_hot_encoding": {
-                              "field": "product",
-                              "hot_map": {
-                                "TV": "type_tv",
-                                "VCR": "type_vcr",
-                                "Laptop": "type_laptop"
-                              }
-                            }
-                          }],
-                          "trained_model": {
-                            "ensemble": {
-                              "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop"],
-                              "target_type": "regression",
-                              "trained_models": [
-                                {
-                                  "tree": {
-                                    "feature_names": [
-                                      "cost"
-                                    ],
-                                    "tree_structure": [
-                                      {
-                                        "node_index": 0,
-                                        "split_feature": 0,
-                                        "split_gain": 12,
-                                        "threshold": 400,
-                                        "decision_type": "lte",
-                                        "default_left": true,
-                                        "left_child": 1,
-                                        "right_child": 2
-                                      },
-                                      {
-                                        "node_index": 1,
-                                        "leaf_value": 5.0
-                                      },
-                                      {
-                                        "node_index": 2,
-                                        "leaf_value": 2.0
-                                      }
-                                    ],
-                                    "target_type": "regression"
-                                  }
-                                },
-                                {
-                                  "tree": {
-                                    "feature_names": [
-                                      "type_tv"
-                                    ],
-                                    "tree_structure": [
-                                      {
-                                        "node_index": 0,
-                                        "split_feature": 0,
-                                        "split_gain": 12,
-                                        "threshold": 1,
-                                        "decision_type": "lt",
-                                        "default_left": true,
-                                        "left_child": 1,
-                                        "right_child": 2
-                                      },
-                                      {
-                                        "node_index": 1,
-                                        "leaf_value": 1.0
-                                      },
-                                      {
-                                        "node_index": 2,
-                                        "leaf_value": 12.0
-                                      }
-                                    ],
-                                    "target_type": "regression"
-                                  }
-                                }
-                              ]
-                            }
-                          }
-                        }
-                      }""");
+               "description": "super complex model for tests",
+               "input": {"field_names": ["cost", "product"]},
+               "inference_config": {
+                 "learn_to_rank": {
+                   "feature_extractors": [{
+                     "query_extractor": {
+                       "feature_name": "two",
+                       "query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}}
+                     }
+                   }]
+                 }
+               },
+               "definition": {
+                 "preprocessors" : [{
+                   "one_hot_encoding": {
+                     "field": "product",
+                     "hot_map": {
+                       "TV": "type_tv",
+                       "VCR": "type_vcr",
+                       "Laptop": "type_laptop"
+                     }
+                   }
+                 }],
+                 "trained_model": {
+                   "ensemble": {
+                     "feature_names": ["cost", "type_tv", "type_vcr", "type_laptop", "two", "product_bm25"],
+                     "target_type": "regression",
+                     "trained_models": [
+                       {
+                         "tree": {
+                           "feature_names": [
+                             "cost"
+                           ],
+                           "tree_structure": [
+                             {
+                               "node_index": 0,
+                               "split_feature": 0,
+                               "split_gain": 12,
+                               "threshold": 400,
+                               "decision_type": "lte",
+                               "default_left": true,
+                               "left_child": 1,
+                               "right_child": 2
+                             },
+                             {
+                               "node_index": 1,
+                               "leaf_value": 5.0
+                             },
+                             {
+                               "node_index": 2,
+                               "leaf_value": 2.0
+                             }
+                           ],
+                           "target_type": "regression"
+                         }
+                       },
+                       {
+                         "tree": {
+                           "feature_names": [
+                             "type_tv"
+                           ],
+                           "tree_structure": [
+                             {
+                               "node_index": 0,
+                               "split_feature": 0,
+                               "split_gain": 12,
+                               "threshold": 1,
+                               "decision_type": "lt",
+                               "default_left": true,
+                               "left_child": 1,
+                               "right_child": 2
+                             },
+                             {
+                               "node_index": 1,
+                               "leaf_value": 1.0
+                             },
+                             {
+                               "node_index": 2,
+                               "leaf_value": 12.0
+                             }
+                           ],
+                           "target_type": "regression"
+                         }
+                       },
+                       {
+                         "tree": {
+                           "feature_names": [
+                             "two"
+                           ],
+                           "tree_structure": [
+                             {
+                               "node_index": 0,
+                               "split_feature": 0,
+                               "split_gain": 12,
+                               "threshold": 1,
+                               "decision_type": "lt",
+                               "default_left": true,
+                               "left_child": 1,
+                               "right_child": 2
+                             },
+                             {
+                               "node_index": 1,
+                               "leaf_value": 1.0
+                             },
+                             {
+                               "node_index": 2,
+                               "leaf_value": 2.0
+                             }
+                           ],
+                           "target_type": "regression"
+                         }
+                       },
+                       {
+                         "tree": {
+                           "feature_names": [
+                             "product_bm25"
+                           ],
+                           "tree_structure": [
+                             {
+                               "node_index": 0,
+                               "split_feature": 0,
+                               "split_gain": 12,
+                               "threshold": 1,
+                               "decision_type": "lt",
+                               "default_left": true,
+                               "left_child": 1,
+                               "right_child": 2
+                             },
+                             {
+                               "node_index": 1,
+                               "leaf_value": 1.0
+                             },
+                             {
+                               "node_index": 2,
+                               "leaf_value": 4.0
+                             }
+                           ],
+                           "target_type": "regression"
+                         }
+                       }
+                     ]
+                   }
+                 }
+               }
+             }""");
         createIndex(INDEX_NAME, Settings.EMPTY, """
             "properties":{
              "product":{"type": "keyword"},
@@ -127,7 +189,7 @@ public class InferenceRescorerIT extends InferenceTestCase {
     }
 
     public void testInferenceRescore() throws Exception {
-        Request request = new Request("GET", "store/_search?size=3");
+        Request request = new Request("GET", "store/_search?size=3&error_trace");
         request.setJsonEntity("""
             {
               "rescore": {
@@ -135,16 +197,34 @@ public class InferenceRescorerIT extends InferenceTestCase {
                 "inference": { "model_id": "ltr-model" }
               }
             }""");
-        assertHitScores(client().performRequest(request), List.of(17.0, 17.0, 14.0));
+        assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 17.0));
         request.setJsonEntity("""
             {
               "query": {"term": {"product": "Laptop"}},
               "rescore": {
                 "window_size": 10,
-                "inference": { "model_id": "ltr-model" }
+                "inference": {
+                  "model_id": "ltr-model",
+                  "inference_config": {
+                    "learn_to_rank": {
+                      "feature_extractors":[{
+                        "query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "Laptop"}}}
+                      }]
+                    }
+                  }
+                }
+              }
+            }""");
+        assertHitScores(client().performRequest(request), List.of(12.0, 12.0, 9.0));
+        request.setJsonEntity("""
+            {
+              "query": {"term": {"product": "Laptop"}},
+              "rescore": {
+                "window_size": 10,
+                "inference": { "model_id": "ltr-model"}
               }
             }""");
-        assertHitScores(client().performRequest(request), List.of(6.0, 6.0, 3.0));
+        assertHitScores(client().performRequest(request), List.of(9.0, 9.0, 6.0));
     }
 
     public void testInferenceRescoreSmallWindow() throws Exception {
@@ -156,7 +236,7 @@ public class InferenceRescorerIT extends InferenceTestCase {
                 "inference": { "model_id": "ltr-model" }
               }
             }""");
-        assertHitScores(client().performRequest(request), List.of(17.0, 17.0, 1.0, 1.0, 1.0));
+        assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 1.0, 1.0, 1.0));
     }
 
     public void testInferenceRescorerWithChainedRescorers() throws IOException {
@@ -178,7 +258,7 @@ public class InferenceRescorerIT extends InferenceTestCase {
                }
               ]
              }""");
-        assertHitScores(client().performRequest(request), List.of(37.0, 37.0, 14.0, 5.0, 1.0));
+        assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 17.0, 5.0, 1.0));
     }
 
     private void indexData(String data) throws IOException {

+ 19 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -199,6 +199,25 @@ public class LocalModel implements Closeable {
         return result.get();
     }
 
+    public InferenceResults inferLtr(Map<String, Object> fields, InferenceConfig config) {
+        statsAccumulator.incInference();
+        currentInferenceCount.increment();
+
+        // We should never have nested maps in a LTR context as we retrieve values from source value extractor, queries, or doc_values
+        assert fields.values().stream().noneMatch(o -> o instanceof Map<?, ?>);
+        // might resolve fields to their appropriate name
+        LocalModel.mapFieldsIfNecessary(fields, defaultFieldMap);
+        boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
+        if (fields.isEmpty()) {
+            statsAccumulator.incMissingFields();
+        }
+        InferenceResults inferenceResults = trainedModelDefinition.infer(fields, config);
+        if (shouldPersistStats) {
+            persistStats(false);
+        }
+        return inferenceResults;
+    }
+
     /**
      * Used for translating field names in according to the passed `fieldMappings` parameter.
      *

+ 13 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java

@@ -15,9 +15,11 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.common.util.Maps;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.rescore.Rescorer;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 
 import java.io.IOException;
@@ -69,7 +71,7 @@ public class InferenceRescorer implements Rescorer {
         List<LeafReaderContext> leaves = ltrRescoreContext.executionContext.searcher().getIndexReader().leaves();
         LeafReaderContext currentSegment = null;
         boolean changedSegment = true;
-        List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors();
+        List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors(searcher);
         List<Map<String, Object>> docFeatures = new ArrayList<>(topNDocIDs.size());
         int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
         while (hitUpto < hitsToRescore.length) {
@@ -95,14 +97,21 @@ public class InferenceRescorer implements Rescorer {
             for (FeatureExtractor featureExtractor : featureExtractors) {
                 featureExtractor.addFeatures(features, targetDoc);
             }
+            logger.debug(() -> Strings.format("doc [%d] has features [%s]", targetDoc, features));
             docFeatures.add(features);
             hitUpto++;
         }
         for (int i = 0; i < hitsToRescore.length; i++) {
             Map<String, Object> features = docFeatures.get(i);
             try {
-                hitsToRescore[i].score = ((Number) definition.infer(features, LearnToRankConfigUpdate.EMPTY_PARAMS).predictedValue())
-                    .floatValue();
+                InferenceResults results = definition.inferLtr(features, ltrRescoreContext.inferenceConfig);
+                if (results instanceof WarningInferenceResults warningInferenceResults) {
+                    logger.warn("Failure rescoring doc, warning returned [" + warningInferenceResults.getWarning() + "]");
+                } else if (results.predictedValue() instanceof Number prediction) {
+                    hitsToRescore[i].score = prediction.floatValue();
+                } else {
+                    logger.warn("Failure rescoring doc, unexpected inference result of kind [" + results.getWriteableName() + "]");
+                }
             } catch (Exception ex) {
                 logger.warn("Failure rescoring doc...", ex);
             }

+ 295 - 34
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java

@@ -10,29 +10,56 @@ package org.elasticsearch.xpack.ml.inference.rescorer;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.function.Supplier;
 
 public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerBuilder> {
 
     public static final String NAME = "inference";
     private static final ParseField MODEL = new ParseField("model_id");
+    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
+    private static final ParseField INTERNAL_INFERENCE_CONFIG = new ParseField("_internal_inference_config");
     private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, false, Builder::new);
     static {
         PARSER.declareString(Builder::setModelId, MODEL);
+        PARSER.declareNamedObject(
+            Builder::setInferenceConfigUpdate,
+            (p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, false),
+            INFERENCE_CONFIG
+        );
+        PARSER.declareNamedObject(
+            Builder::setInferenceConfig,
+            (p, c, name) -> p.namedObject(StrictlyParsedInferenceConfig.class, name, false),
+            INTERNAL_INFERENCE_CONFIG
+        );
     }
 
     public static InferenceRescorerBuilder fromXContent(XContentParser parser, Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
@@ -40,41 +67,86 @@ public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerB
     }
 
     private final String modelId;
+    private final LearnToRankConfigUpdate inferenceConfigUpdate;
+    private final LearnToRankConfig inferenceConfig;
     private final LocalModel inferenceDefinition;
     private final Supplier<LocalModel> inferenceDefinitionSupplier;
     private final Supplier<ModelLoadingService> modelLoadingServiceSupplier;
+    private final Supplier<LearnToRankConfig> inferenceConfigSupplier;
     private boolean rescoreOccurred;
 
-    public InferenceRescorerBuilder(String modelId, Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
+    public InferenceRescorerBuilder(
+        String modelId,
+        LearnToRankConfigUpdate inferenceConfigUpdate,
+        Supplier<ModelLoadingService> modelLoadingServiceSupplier
+    ) {
         this.modelId = Objects.requireNonNull(modelId);
+        this.inferenceConfigUpdate = inferenceConfigUpdate;
         this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
         this.inferenceDefinition = null;
         this.inferenceDefinitionSupplier = null;
+        this.inferenceConfigSupplier = null;
+        this.inferenceConfig = null;
     }
 
-    InferenceRescorerBuilder(String modelId, LocalModel inferenceDefinition) {
+    InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
         this.modelId = Objects.requireNonNull(modelId);
-        this.inferenceDefinition = Objects.requireNonNull(inferenceDefinition);
+        this.inferenceConfigUpdate = null;
+        this.inferenceDefinition = null;
         this.inferenceDefinitionSupplier = null;
-        this.modelLoadingServiceSupplier = null;
+        this.inferenceConfigSupplier = null;
+        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
+        this.inferenceConfig = Objects.requireNonNull(inferenceConfig);
     }
 
     private InferenceRescorerBuilder(
         String modelId,
+        LearnToRankConfigUpdate update,
+        Supplier<ModelLoadingService> modelLoadingServiceSupplier,
+        Supplier<LearnToRankConfig> inferenceConfigSupplier
+    ) {
+        this.modelId = Objects.requireNonNull(modelId);
+        this.inferenceConfigUpdate = update;
+        this.inferenceDefinition = null;
+        this.inferenceDefinitionSupplier = null;
+        this.inferenceConfigSupplier = inferenceConfigSupplier;
+        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
+        this.inferenceConfig = null;
+    }
+
+    private InferenceRescorerBuilder(
+        String modelId,
+        LearnToRankConfig inferenceConfig,
         Supplier<ModelLoadingService> modelLoadingServiceSupplier,
         Supplier<LocalModel> inferenceDefinitionSupplier
     ) {
         this.modelId = modelId;
+        this.inferenceConfigUpdate = null;
         this.inferenceDefinition = null;
         this.inferenceDefinitionSupplier = inferenceDefinitionSupplier;
         this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
+        this.inferenceConfigSupplier = null;
+        this.inferenceConfig = inferenceConfig;
+    }
+
+    InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, LocalModel inferenceDefinition) {
+        this.modelId = modelId;
+        this.inferenceConfigUpdate = null;
+        this.inferenceDefinition = inferenceDefinition;
+        this.inferenceDefinitionSupplier = null;
+        this.modelLoadingServiceSupplier = null;
+        this.inferenceConfigSupplier = null;
+        this.inferenceConfig = inferenceConfig;
     }
 
     public InferenceRescorerBuilder(StreamInput input, Supplier<ModelLoadingService> modelLoadingServiceSupplier) throws IOException {
         super(input);
         this.modelId = input.readString();
+        this.inferenceConfigUpdate = (LearnToRankConfigUpdate) input.readOptionalNamedWriteable(InferenceConfigUpdate.class);
         this.inferenceDefinitionSupplier = null;
+        this.inferenceConfigSupplier = null;
         this.inferenceDefinition = null;
+        this.inferenceConfig = (LearnToRankConfig) input.readOptionalNamedWriteable(InferenceConfig.class);
         this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
     }
 
@@ -92,69 +164,199 @@ public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerB
         return TransportVersion.current();
     }
 
-    @Override
-    public RescorerBuilder<InferenceRescorerBuilder> rewrite(QueryRewriteContext ctx) throws IOException {
-        if (inferenceDefinition != null) {
+    /**
+     * Here we fetch the stored model inference context, apply the given update, and rewrite.
+     *
+     * This can and be done on the coordinator as it not only validates if the stored model is of the appropriate type, it allows
+     * any stored logic to rewrite on the coordinator level if possible.
+     * @param ctx QueryRewriteContext
+     * @return rewritten InferenceRescorerBuilder or self if no changes
+     * @throws IOException when rewrite fails
+     */
+    private RescorerBuilder<InferenceRescorerBuilder> doRewrite(QueryRewriteContext ctx) throws IOException {
+        // Awaiting fetch
+        if (inferenceConfigSupplier != null && inferenceConfigSupplier.get() == null) {
             return this;
         }
-        if (inferenceDefinitionSupplier != null) {
-            if (inferenceDefinitionSupplier.get() == null) {
+        if (inferenceConfig != null) {
+            LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfig, ctx);
+            if (rewrittenConfig == inferenceConfig) {
                 return this;
             }
-            LocalModel inferenceDefinition = inferenceDefinitionSupplier.get();
-            InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, inferenceDefinition);
-            if (windowSize() != null) {
-                builder.windowSize(windowSize());
+            InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier);
+            if (windowSize != null) {
+                builder.windowSize(windowSize);
             }
             return builder;
         }
-        // We don't want to rewrite on the coordinator as that doesn't make sense for this rescorer
-        if (ctx.convertToDataRewriteContext() != null) {
-            if (modelLoadingServiceSupplier == null || modelLoadingServiceSupplier.get() == null) {
-                throw new IllegalStateException("Model loading service must be available");
+        // We have requested for the stored config and fetch is completed, get the config and rewrite further if required
+        if (inferenceConfigSupplier != null) {
+            LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfigSupplier.get(), ctx);
+            InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier);
+            if (windowSize != null) {
+                builder.windowSize(windowSize);
             }
-            SetOnce<LocalModel> inferenceDefinitionSetOnce = new SetOnce<>();
-            ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForLearnToRank(modelId, ActionListener.wrap(lm -> {
-                inferenceDefinitionSetOnce.set(lm);
-                l.onResponse(null);
-            }, l::onFailure)));
-            InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
-                modelId,
-                modelLoadingServiceSupplier,
-                inferenceDefinitionSetOnce::get
-            );
+            return builder;
+        }
+        SetOnce<LearnToRankConfig> configSetOnce = new SetOnce<>();
+        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId);
+        request.setAllowNoResources(false);
+        ctx.registerAsyncAction(
+            (c, l) -> ClientHelper.executeAsyncWithOrigin(
+                c,
+                ClientHelper.ML_ORIGIN,
+                GetTrainedModelsAction.INSTANCE,
+                request,
+                ActionListener.wrap(trainedModels -> {
+                    TrainedModelConfig config = trainedModels.getResources().results().get(0);
+                    if (config.getInferenceConfig() instanceof LearnToRankConfig retrievedInferenceConfig) {
+                        retrievedInferenceConfig = inferenceConfigUpdate == null
+                            ? retrievedInferenceConfig
+                            : inferenceConfigUpdate.apply(retrievedInferenceConfig);
+                        for (LearnToRankFeatureExtractorBuilder builder : retrievedInferenceConfig.getFeatureExtractorBuilders()) {
+                            builder.validate();
+                        }
+                        configSetOnce.set(retrievedInferenceConfig);
+                        l.onResponse(null);
+                        return;
+                    }
+                    l.onFailure(
+                        ExceptionsHelper.badRequestException(
+                            Messages.getMessage(
+                                Messages.INFERENCE_CONFIG_INCORRECT_TYPE,
+                                Optional.ofNullable(config.getInferenceConfig()).map(InferenceConfig::getName).orElse("null"),
+                                LearnToRankConfig.NAME.getPreferredName()
+                            )
+                        )
+                    );
+                }, l::onFailure)
+            )
+        );
+        InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
+            modelId,
+            inferenceConfigUpdate,
+            modelLoadingServiceSupplier,
+            configSetOnce::get
+        );
+        if (windowSize() != null) {
+            builder.windowSize(windowSize);
+        }
+        return builder;
+    }
+
+    /**
+     * This rewrite phase occurs on the data node when we know we will want to use the model for inference
+     * @param ctx Rewrite context
+     * @return A rewritten rescorer with a model definition or a model definition supplier populated
+     */
+    private RescorerBuilder<InferenceRescorerBuilder> doDataNodeRewrite(QueryRewriteContext ctx) {
+        assert inferenceConfig != null;
+        // We already have an inference definition, no need to do any rewriting
+        if (inferenceDefinition != null) {
+            return this;
+        }
+        // Awaiting fetch
+        if (inferenceDefinitionSupplier != null && inferenceDefinitionSupplier.get() == null) {
+            return this;
+        }
+        if (inferenceDefinitionSupplier != null) {
+            LocalModel inferenceDefinition = inferenceDefinitionSupplier.get();
+            InferenceRescorerBuilder builder = new InferenceRescorerBuilder(modelId, inferenceConfig, inferenceDefinition);
             if (windowSize() != null) {
                 builder.windowSize(windowSize());
             }
             return builder;
         }
-        return this;
+        if (modelLoadingServiceSupplier == null || modelLoadingServiceSupplier.get() == null) {
+            throw new IllegalStateException("Model loading service must be available");
+        }
+        SetOnce<LocalModel> inferenceDefinitionSetOnce = new SetOnce<>();
+        ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForLearnToRank(modelId, ActionListener.wrap(lm -> {
+            inferenceDefinitionSetOnce.set(lm);
+            l.onResponse(null);
+        }, l::onFailure)));
+        InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
+            modelId,
+            inferenceConfig,
+            modelLoadingServiceSupplier,
+            inferenceDefinitionSetOnce::get
+        );
+        if (windowSize() != null) {
+            builder.windowSize(windowSize());
+        }
+        return builder;
+    }
+
+    /**
+     * This rewrite phase occurs on the data node when we know we will want to use the model for inference
+     * @param ctx Rewrite context
+     * @return A rewritten rescorer with a model definition or a model definition supplier populated
+     * @throws IOException If fetching, parsing, or overall rewrite failures occur
+     */
+    private RescorerBuilder<InferenceRescorerBuilder> doSearchRewrite(QueryRewriteContext ctx) throws IOException {
+        if (inferenceConfig == null) {
+            return this;
+        }
+        LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(inferenceConfig, ctx);
+        if (rewrittenConfig == inferenceConfig) {
+            return this;
+        }
+        InferenceRescorerBuilder builder = inferenceDefinition == null
+            ? new InferenceRescorerBuilder(modelId, rewrittenConfig, modelLoadingServiceSupplier)
+            : new InferenceRescorerBuilder(modelId, rewrittenConfig, inferenceDefinition);
+        if (windowSize != null) {
+            builder.windowSize(windowSize);
+        }
+        return builder;
+    }
+
+    @Override
+    public RescorerBuilder<InferenceRescorerBuilder> rewrite(QueryRewriteContext ctx) throws IOException {
+        if (ctx.convertToDataRewriteContext() != null) {
+            return doDataNodeRewrite(ctx);
+        }
+        if (ctx.convertToSearchExecutionContext() != null) {
+            return doSearchRewrite(ctx);
+        }
+        return doRewrite(ctx);
     }
 
     public String getModelId() {
         return modelId;
     }
 
+    LearnToRankConfig getInferenceConfig() {
+        return inferenceConfig;
+    }
+
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
-        if (inferenceDefinitionSupplier != null) {
-            throw new IllegalStateException("supplier must be null, missing a rewriteAndFetch?");
+        if (inferenceDefinitionSupplier != null || inferenceConfigSupplier != null) {
+            throw new IllegalStateException("suppliers must be null, missing a rewriteAndFetch?");
         }
         assert inferenceDefinition == null || rescoreOccurred : "Unnecessarily populated local model object";
         out.writeString(modelId);
+        out.writeOptionalNamedWriteable(inferenceConfigUpdate);
+        out.writeOptionalNamedWriteable(inferenceConfig);
     }
 
     @Override
     protected void doXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject(NAME);
         builder.field(MODEL.getPreferredName(), modelId);
+        if (inferenceConfigUpdate != null) {
+            NamedXContentObjectHelper.writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfigUpdate);
+        }
+        if (inferenceConfig != null) {
+            NamedXContentObjectHelper.writeNamedObject(builder, params, INTERNAL_INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
+        }
         builder.endObject();
     }
 
     @Override
     protected InferenceRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) {
         rescoreOccurred = true;
-        return new InferenceRescorerContext(windowSize, InferenceRescorer.INSTANCE, inferenceDefinition, context);
+        return new InferenceRescorerContext(windowSize, InferenceRescorer.INSTANCE, inferenceConfig, inferenceDefinition, context);
     }
 
     @Override
@@ -165,24 +367,83 @@ public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerB
         InferenceRescorerBuilder that = (InferenceRescorerBuilder) o;
         return Objects.equals(modelId, that.modelId)
             && Objects.equals(inferenceDefinition, that.inferenceDefinition)
+            && Objects.equals(inferenceConfigUpdate, that.inferenceConfigUpdate)
+            && Objects.equals(inferenceConfig, that.inferenceConfig)
             && Objects.equals(inferenceDefinitionSupplier, that.inferenceDefinitionSupplier)
             && Objects.equals(modelLoadingServiceSupplier, that.modelLoadingServiceSupplier);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), modelId, inferenceDefinition, inferenceDefinitionSupplier, modelLoadingServiceSupplier);
+        return Objects.hash(
+            super.hashCode(),
+            modelId,
+            inferenceConfigUpdate,
+            inferenceConfig,
+            inferenceDefinition,
+            inferenceDefinitionSupplier,
+            modelLoadingServiceSupplier
+        );
+    }
+
+    LearnToRankConfigUpdate getInferenceConfigUpdate() {
+        return inferenceConfigUpdate;
+    }
+
+    // Used in tests
+    Supplier<ModelLoadingService> modelLoadingServiceSupplier() {
+        return modelLoadingServiceSupplier;
+    }
+
+    // Used in tests
+    LocalModel getInferenceDefinition() {
+        return inferenceDefinition;
     }
 
-    private static class Builder {
+    static class Builder {
         private String modelId;
+        private LearnToRankConfigUpdate inferenceConfigUpdate;
+        private LearnToRankConfig inferenceConfig;
 
         public void setModelId(String modelId) {
             this.modelId = modelId;
         }
 
+        public void setInferenceConfigUpdate(InferenceConfigUpdate inferenceConfigUpdate) {
+            if (inferenceConfigUpdate instanceof LearnToRankConfigUpdate learnToRankConfigUpdate) {
+                this.inferenceConfigUpdate = learnToRankConfigUpdate;
+                return;
+            }
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "[%s] only allows a [%s] object to be configured",
+                    INFERENCE_CONFIG.getPreferredName(),
+                    LearnToRankConfigUpdate.NAME.getPreferredName()
+                )
+            );
+        }
+
+        void setInferenceConfig(InferenceConfig inferenceConfig) {
+            if (inferenceConfig instanceof LearnToRankConfig learnToRankConfig) {
+                this.inferenceConfig = learnToRankConfig;
+                return;
+            }
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "[%s] only allows a [%s] object to be configured",
+                    INFERENCE_CONFIG.getPreferredName(),
+                    LearnToRankConfigUpdate.NAME.getPreferredName()
+                )
+            );
+        }
+
         InferenceRescorerBuilder build(Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
-            return new InferenceRescorerBuilder(modelId, modelLoadingServiceSupplier);
+            assert inferenceConfig == null || inferenceConfigUpdate == null;
+            if (inferenceConfig != null) {
+                return new InferenceRescorerBuilder(modelId, inferenceConfig, modelLoadingServiceSupplier);
+            } else {
+                return new InferenceRescorerBuilder(modelId, inferenceConfigUpdate, modelLoadingServiceSupplier);
+            }
         }
     }
 }

+ 43 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java

@@ -7,11 +7,20 @@
 
 package org.elasticsearch.xpack.ml.inference.rescorer;
 
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.rescore.Rescorer;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -19,32 +28,64 @@ public class InferenceRescorerContext extends RescoreContext {
 
     final SearchExecutionContext executionContext;
     final LocalModel inferenceDefinition;
+    final LearnToRankConfig inferenceConfig;
 
     /**
      * @param windowSize how many documents to rescore
      * @param rescorer The rescorer to apply
+     * @param inferenceConfig The inference config containing updated and rewritten parameters
      * @param inferenceDefinition The local model inference definition, may be null during certain search phases.
      * @param executionContext The local shard search context
      */
     public InferenceRescorerContext(
         int windowSize,
         Rescorer rescorer,
+        LearnToRankConfig inferenceConfig,
         LocalModel inferenceDefinition,
         SearchExecutionContext executionContext
     ) {
         super(windowSize, rescorer);
         this.executionContext = executionContext;
         this.inferenceDefinition = inferenceDefinition;
+        this.inferenceConfig = inferenceConfig;
     }
 
-    List<FeatureExtractor> buildFeatureExtractors() {
-        assert this.inferenceDefinition != null;
+    List<FeatureExtractor> buildFeatureExtractors(IndexSearcher searcher) throws IOException {
+        assert this.inferenceDefinition != null && this.inferenceConfig != null;
         List<FeatureExtractor> featureExtractors = new ArrayList<>();
         if (this.inferenceDefinition.inputFields().isEmpty() == false) {
             featureExtractors.add(
                 new FieldValueFeatureExtractor(new ArrayList<>(this.inferenceDefinition.inputFields()), this.executionContext)
             );
         }
+        List<Weight> weights = new ArrayList<>();
+        List<String> queryFeatureNames = new ArrayList<>();
+        for (LearnToRankFeatureExtractorBuilder featureExtractorBuilder : inferenceConfig.getFeatureExtractorBuilders()) {
+            if (featureExtractorBuilder instanceof QueryExtractorBuilder queryExtractorBuilder) {
+                Query query = executionContext.toQuery(queryExtractorBuilder.query().getParsedQuery()).query();
+                Weight weight = searcher.rewrite(query).createWeight(searcher, ScoreMode.COMPLETE, 1f);
+                weights.add(weight);
+                queryFeatureNames.add(queryExtractorBuilder.featureName());
+            }
+        }
+        if (weights.isEmpty() == false) {
+            featureExtractors.add(new QueryFeatureExtractor(queryFeatureNames, weights));
+        }
+
         return featureExtractors;
     }
+
+    @Override
+    public List<ParsedQuery> getParsedQueries() {
+        if (this.inferenceConfig == null) {
+            return List.of();
+        }
+        List<ParsedQuery> parsedQueries = new ArrayList<>();
+        for (LearnToRankFeatureExtractorBuilder featureExtractorBuilder : inferenceConfig.getFeatureExtractorBuilders()) {
+            if (featureExtractorBuilder instanceof QueryExtractorBuilder queryExtractorBuilder) {
+                parsedQueries.add(executionContext.toQuery(queryExtractorBuilder.query().getParsedQuery()));
+            }
+        }
+        return parsedQueries;
+    }
 }

+ 78 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/QueryFeatureExtractor.java

@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.rescorer;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DisjunctionDISIApproximation;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Extracts query features, e.g. _scores, from the provided weights and featureNames.
+ * For every document provided, this extractor iterates with the constructed scorers and collects the _score (if matched) for the
+ * respective feature name.
+ */
+public class QueryFeatureExtractor implements FeatureExtractor {
+
+    private final List<String> featureNames;
+    private final List<Weight> weights;
+    private final List<Scorer> scorers;
+    private DisjunctionDISIApproximation rankerIterator;
+
+    public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
+        if (featureNames.size() != weights.size()) {
+            throw new IllegalArgumentException("[featureNames] and [weights] must be the same size.");
+        }
+        this.featureNames = featureNames;
+        this.weights = weights;
+        this.scorers = new ArrayList<>(weights.size());
+    }
+
+    @Override
+    public void setNextReader(LeafReaderContext segmentContext) throws IOException {
+        DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
+        scorers.clear();
+        for (Weight weight : weights) {
+            if (weight == null) {
+                scorers.add(null);
+                continue;
+            }
+            Scorer scorer = weight.scorer(segmentContext);
+            if (scorer != null) {
+                disiPriorityQueue.add(new DisiWrapper(scorer));
+            }
+            scorers.add(scorer);
+        }
+        rankerIterator = new DisjunctionDISIApproximation(disiPriorityQueue);
+    }
+
+    @Override
+    public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
+        rankerIterator.advance(docId);
+        for (int i = 0; i < featureNames.size(); i++) {
+            Scorer scorer = scorers.get(i);
+            // Do we have a scorer, and does it match the provided document?
+            if (scorer != null && scorer.docID() == docId) {
+                featureMap.put(featureNames.get(i), scorer.score());
+            }
+        }
+    }
+
+    @Override
+    public List<String> featureNames() {
+        return featureNames;
+    }
+
+}

+ 175 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java

@@ -11,17 +11,41 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.query.CoordinatorRewriteContext;
 import org.elasticsearch.index.query.DataRewriteContext;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.test.AbstractBuilderTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -29,19 +53,48 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.io.IOException;
+import java.lang.reflect.Method;
 import java.util.List;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.in;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCase {
 
+    private static final String GOOD_MODEL = "modelId";
+    private static final String BAD_MODEL = "badModel";
+    private static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder()
+        .setModelId(GOOD_MODEL)
+        .setInput(new TrainedModelInput(List.of("field1", "field2")))
+        .setEstimatedOperations(1)
+        .setModelSize(2)
+        .setModelType(TrainedModelType.TREE_ENSEMBLE)
+        .setInferenceConfig(new LearnToRankConfig(null, null))
+        .build();
+    private static final TrainedModelConfig BAD_MODEL_CONFIG = TrainedModelConfig.builder()
+        .setModelId(BAD_MODEL)
+        .setInput(new TrainedModelInput(List.of("field1", "field2")))
+        .setEstimatedOperations(1)
+        .setModelSize(2)
+        .setModelType(TrainedModelType.TREE_ENSEMBLE)
+        .setInferenceConfig(new RegressionConfig(null, null))
+        .build();
+
     public void testMustRewrite() {
         TestModelLoader testModelLoader = new TestModelLoader();
-        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("modelId", () -> testModelLoader);
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            GOOD_MODEL,
+            LearnToRankConfigTests.randomLearnToRankConfig(),
+            () -> testModelLoader
+        );
         SearchExecutionContext context = createSearchExecutionContext();
         InferenceRescorerContext inferenceRescorerContext = inferenceRescorerBuilder.innerBuildContext(randomIntBetween(1, 30), context);
         IllegalStateException e = expectThrows(
@@ -58,20 +111,124 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas
 
     public void testRewriteOnCoordinator() throws IOException {
         TestModelLoader testModelLoader = new TestModelLoader();
-        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("modelId", () -> testModelLoader);
+        LearnToRankConfigUpdate ltru = new LearnToRankConfigUpdate(
+            2,
+            List.of(new QueryExtractorBuilder("all", QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery())))
+        );
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, ltru, () -> testModelLoader);
+        inferenceRescorerBuilder.windowSize(4);
         CoordinatorRewriteContext context = createCoordinatorRewriteContext(
             new DateFieldMapper.DateFieldType("@timestamp"),
             randomIntBetween(0, 1_100_000),
             randomIntBetween(1_500_000, Integer.MAX_VALUE)
         );
-        InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) inferenceRescorerBuilder.rewrite(context);
-        assertSame(inferenceRescorerBuilder, rewritten);
-        assertFalse(context.hasAsyncActions());
+        InferenceRescorerBuilder rewritten = rewriteAndFetch(inferenceRescorerBuilder, context);
+        assertThat(rewritten.getInferenceConfig(), not(nullValue()));
+        assertThat(rewritten.getInferenceConfig().getNumTopFeatureImportanceValues(), equalTo(2));
+        assertThat(
+            "all",
+            is(
+                in(
+                    rewritten.getInferenceConfig()
+                        .getFeatureExtractorBuilders()
+                        .stream()
+                        .map(LearnToRankFeatureExtractorBuilder::featureName)
+                        .toList()
+                )
+            )
+        );
+        assertThat(rewritten.getInferenceConfigUpdate(), is(nullValue()));
+        assertThat(rewritten.windowSize(), equalTo(4));
+    }
+
+    public void testRewriteOnCoordinatorWithBadModel() throws IOException {
+        TestModelLoader testModelLoader = new TestModelLoader();
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            BAD_MODEL,
+            randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(),
+            () -> testModelLoader
+        );
+        CoordinatorRewriteContext context = createCoordinatorRewriteContext(
+            new DateFieldMapper.DateFieldType("@timestamp"),
+            randomIntBetween(0, 1_100_000),
+            randomIntBetween(1_500_000, Integer.MAX_VALUE)
+        );
+        ElasticsearchStatusException ex = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> rewriteAndFetch(inferenceRescorerBuilder, context)
+        );
+        assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST));
+    }
+
+    public void testRewriteOnCoordinatorWithMissingModel() {
+        TestModelLoader testModelLoader = new TestModelLoader();
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            "missing_model",
+            randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(),
+            () -> testModelLoader
+        );
+        CoordinatorRewriteContext context = createCoordinatorRewriteContext(
+            new DateFieldMapper.DateFieldType("@timestamp"),
+            randomIntBetween(0, 1_100_000),
+            randomIntBetween(1_500_000, Integer.MAX_VALUE)
+        );
+        expectThrows(ResourceNotFoundException.class, () -> rewriteAndFetch(inferenceRescorerBuilder, context));
+    }
+
+    public void testSearchRewrite() throws IOException {
+        TestModelLoader testModelLoader = new TestModelLoader();
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            GOOD_MODEL,
+            LearnToRankConfigTests.randomLearnToRankConfig(),
+            () -> testModelLoader
+        );
+        QueryRewriteContext context = createSearchExecutionContext();
+        InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) Rewriteable.rewrite(inferenceRescorerBuilder, context, true);
+        assertThat(rewritten.modelLoadingServiceSupplier(), is(notNullValue()));
+
+        inferenceRescorerBuilder = new InferenceRescorerBuilder(GOOD_MODEL, LearnToRankConfigTests.randomLearnToRankConfig(), localModel());
+
+        rewritten = (InferenceRescorerBuilder) Rewriteable.rewrite(inferenceRescorerBuilder, context, true);
+        assertThat(rewritten.modelLoadingServiceSupplier(), is(nullValue()));
+        assertThat(rewritten.getInferenceDefinition(), is(notNullValue()));
+    }
+
+    protected InferenceRescorerBuilder rewriteAndFetch(RescorerBuilder<InferenceRescorerBuilder> builder, QueryRewriteContext context) {
+        PlainActionFuture<RescorerBuilder<InferenceRescorerBuilder>> future = new PlainActionFuture<>();
+        Rewriteable.rewriteAndFetch(builder, context, future);
+        return (InferenceRescorerBuilder) future.actionGet();
+    }
+
+    @Override
+    protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException {
+        return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class))
+            && (args[0] instanceof GetTrainedModelsAction);
+    }
+
+    @Override
+    protected Object simulateMethod(Method method, Object[] args) {
+        GetTrainedModelsAction.Request request = (GetTrainedModelsAction.Request) args[1];
+        @SuppressWarnings("unchecked")  // We matched the method above.
+        ActionListener<GetTrainedModelsAction.Response> listener = (ActionListener<GetTrainedModelsAction.Response>) args[2];
+        if (request.getResourceId().equals(GOOD_MODEL)) {
+            listener.onResponse(GetTrainedModelsAction.Response.builder().setModels(List.of(GOOD_MODEL_CONFIG)).build());
+            return null;
+        }
+        if (request.getResourceId().equals(BAD_MODEL)) {
+            listener.onResponse(GetTrainedModelsAction.Response.builder().setModels(List.of(BAD_MODEL_CONFIG)).build());
+            return null;
+        }
+        listener.onFailure(ExceptionsHelper.missingTrainedModel(request.getResourceId()));
+        return null;
     }
 
     public void testRewriteOnShard() throws IOException {
         TestModelLoader testModelLoader = new TestModelLoader();
-        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("modelId", () -> testModelLoader);
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            GOOD_MODEL,
+            (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(),
+            () -> testModelLoader
+        );
         SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
         InferenceRescorerBuilder rewritten = (InferenceRescorerBuilder) inferenceRescorerBuilder.rewrite(createSearchExecutionContext());
         assertSame(inferenceRescorerBuilder, rewritten);
@@ -80,7 +237,11 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas
 
     public void testRewriteAndFetchOnDataNode() throws IOException {
         TestModelLoader testModelLoader = new TestModelLoader();
-        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("modelId", () -> testModelLoader);
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            GOOD_MODEL,
+            (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(),
+            () -> testModelLoader
+        );
         boolean setWindowSize = randomBoolean();
         if (setWindowSize) {
             inferenceRescorerBuilder.windowSize(42);
@@ -94,16 +255,20 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas
         }
     }
 
-    public void testBuildContext() {
+    public void testBuildContext() throws Exception {
         LocalModel localModel = localModel();
         List<String> inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME);
         when(localModel.inputFields()).thenReturn(inputFields);
         SearchExecutionContext context = createSearchExecutionContext();
-        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("test_model", localModel);
+        InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder(
+            GOOD_MODEL,
+            (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(),
+            localModel
+        );
         InferenceRescorerContext rescoreContext = inferenceRescorerBuilder.innerBuildContext(20, context);
         assertNotNull(rescoreContext);
         assertThat(rescoreContext.getWindowSize(), equalTo(20));
-        List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors();
+        List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
         assertThat(featureExtractors, hasSize(1));
         assertThat(
             featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(),

+ 101 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderSerializationTests.java

@@ -9,12 +9,25 @@ package org.elasticsearch.xpack.ml.inference.rescorer;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdateTests;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.function.Supplier;
 
 import static org.elasticsearch.search.rank.RankBuilder.WINDOW_SIZE_FIELD;
@@ -59,7 +72,17 @@ public class InferenceRescorerBuilderSerializationTests extends AbstractBWCSeria
 
     @Override
     protected InferenceRescorerBuilder createTestInstance() {
-        InferenceRescorerBuilder builder = new InferenceRescorerBuilder(randomAlphaOfLength(10), (Supplier<ModelLoadingService>) null);
+        InferenceRescorerBuilder builder = randomBoolean()
+            ? new InferenceRescorerBuilder(
+                randomAlphaOfLength(10),
+                randomBoolean() ? null : LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate(),
+                null
+            )
+            : new InferenceRescorerBuilder(
+                randomAlphaOfLength(10),
+                LearnToRankConfigTests.randomLearnToRankConfig(),
+                (Supplier<ModelLoadingService>) null
+            );
         if (randomBoolean()) {
             builder.windowSize(randomIntBetween(1, 10000));
         }
@@ -68,15 +91,44 @@ public class InferenceRescorerBuilderSerializationTests extends AbstractBWCSeria
 
     @Override
     protected InferenceRescorerBuilder mutateInstance(InferenceRescorerBuilder instance) throws IOException {
-        int i = randomInt(1);
+        int i = randomInt(3);
         return switch (i) {
-            case 0 -> new InferenceRescorerBuilder(
-                randomValueOtherThan(instance.getModelId(), () -> randomAlphaOfLength(10)),
-                (Supplier<ModelLoadingService>) null
-            );
-            case 1 -> new InferenceRescorerBuilder(instance.getModelId(), (Supplier<ModelLoadingService>) null).windowSize(
+            case 0 -> {
+                InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
+                    randomValueOtherThan(instance.getModelId(), () -> randomAlphaOfLength(10)),
+                    instance.getInferenceConfigUpdate(),
+                    null
+                );
+                if (instance.windowSize() != null) {
+                    builder.windowSize(instance.windowSize());
+                }
+                yield builder;
+            }
+            case 1 -> new InferenceRescorerBuilder(instance.getModelId(), instance.getInferenceConfigUpdate(), null).windowSize(
                 randomValueOtherThan(instance.windowSize(), () -> randomIntBetween(1, 10000))
             );
+            case 2 -> {
+                InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
+                    instance.getModelId(),
+                    randomValueOtherThan(instance.getInferenceConfigUpdate(), LearnToRankConfigUpdateTests::randomLearnToRankConfigUpdate),
+                    null
+                );
+                if (instance.windowSize() != null) {
+                    builder.windowSize(instance.windowSize());
+                }
+                yield builder;
+            }
+            case 3 -> {
+                InferenceRescorerBuilder builder = new InferenceRescorerBuilder(
+                    instance.getModelId(),
+                    randomValueOtherThan(instance.getInferenceConfig(), LearnToRankConfigTests::randomLearnToRankConfig),
+                    (Supplier<ModelLoadingService>) null
+                );
+                if (instance.windowSize() != null) {
+                    builder.windowSize(instance.windowSize());
+                }
+                yield builder;
+            }
             default -> throw new AssertionError("Unexpected random test case");
         };
     }
@@ -85,4 +137,46 @@ public class InferenceRescorerBuilderSerializationTests extends AbstractBWCSeria
     protected InferenceRescorerBuilder mutateInstanceForVersion(InferenceRescorerBuilder instance, TransportVersion version) {
         return instance;
     }
+
+    public void testIncorrectInferenceConfigUpdateType() {
+        InferenceRescorerBuilder.Builder builder = new InferenceRescorerBuilder.Builder();
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> builder.setInferenceConfigUpdate(ClassificationConfigUpdateTests.randomClassificationConfigUpdate())
+        );
+        // Should not throw
+        builder.setInferenceConfigUpdate(LearnToRankConfigUpdateTests.randomLearnToRankConfigUpdate());
+    }
+
+    public void testIncorrectInferenceConfigType() {
+        InferenceRescorerBuilder.Builder builder = new InferenceRescorerBuilder.Builder();
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> builder.setInferenceConfig(ClassificationConfigTests.randomClassificationConfig())
+        );
+        // Should not throw
+        builder.setInferenceConfig(LearnToRankConfigTests.randomLearnToRankConfig());
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry writableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return writableRegistry();
+    }
 }

+ 127 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/QueryFeatureExtractorTests.java

@@ -0,0 +1,127 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.rescorer;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.similarities.ClassicSimilarity;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.test.AbstractBuilderTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.anEmptyMap;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasKey;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+
+public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
+
+    private Directory dir;
+    private IndexReader reader;
+    private IndexSearcher searcher;
+
+    private void addDocs(String[] textValues, int[] numberValues) throws IOException {
+        dir = newDirectory();
+        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
+            for (int i = 0; i < textValues.length; i++) {
+                Document doc = new Document();
+                doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
+                doc.add(new IntField(INT_FIELD_NAME, numberValues[i], Field.Store.YES));
+                indexWriter.addDocument(doc);
+                if (randomBoolean()) {
+                    indexWriter.flush();
+                }
+            }
+            reader = indexWriter.getReader();
+        }
+        searcher = newSearcher(reader);
+        searcher.setSimilarity(new ClassicSimilarity());
+    }
+
+    public void testQueryExtractor() throws IOException {
+        addDocs(
+            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+            new int[] { 5, 10, 12, 11 }
+        );
+        QueryRewriteContext ctx = createQueryRewriteContext();
+        List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
+            new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
+                .rewrite(ctx),
+            new QueryExtractorBuilder(
+                "number_score",
+                QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
+            ).rewrite(ctx),
+            new QueryExtractorBuilder(
+                "matching_none",
+                QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
+            ).rewrite(ctx),
+            new QueryExtractorBuilder(
+                "matching_missing_field",
+                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
+            ).rewrite(ctx)
+        );
+        SearchExecutionContext dummySEC = createSearchExecutionContext();
+        List<Weight> weights = new ArrayList<>();
+        List<String> featureNames = new ArrayList<>();
+        for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
+            Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
+            Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
+            weights.add(weight);
+            featureNames.add(qeb.featureName());
+        }
+        QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
+        List<Map<String, Object>> extractedFeatures = new ArrayList<>();
+        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+            int maxDoc = leafReaderContext.reader().maxDoc();
+            queryFeatureExtractor.setNextReader(leafReaderContext);
+            for (int i = 0; i < maxDoc; i++) {
+                Map<String, Object> featureMap = new HashMap<>();
+                queryFeatureExtractor.addFeatures(featureMap, i);
+                extractedFeatures.add(featureMap);
+            }
+        }
+        assertThat(extractedFeatures, hasSize(4));
+        // Should never add features for queries that don't match a document or on documents where the field is missing
+        for (Map<String, Object> features : extractedFeatures) {
+            assertThat(features, not(hasKey("matching_none")));
+            assertThat(features, not(hasKey("matching_missing_field")));
+        }
+        // First two only match the text field
+        assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
+        assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
+        assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
+        assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
+        // Only matches the range query
+        assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
+        assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
+        // No query matches
+        assertThat(extractedFeatures.get(3), anEmptyMap());
+        reader.close();
+        dir.close();
+    }
+
+}