Browse Source

[ML] add new `learn_to_rank` inference config (#97198)

This adds a new inference config for utilizing a `learn_to_rank`
configuration within the `inference_rescorer` context.

This inference config allows setting typical `regression` config
settings, with the addition of a new "feature_extractors" named objects.
Right now, there are no feature extractor builders. But there will be
for calculating and extracting user provided and query & document
interactions.

Additionally, this commit cleans up the feature extraction in the
rescorer, refactoring out the field value extraction into an
encapsulated class.
Benjamin Trent 2 years ago
parent
commit
e99f36e8d9
24 changed files with 1167 additions and 80 deletions
  1. 74 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java
  2. 12 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
  3. 201 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java
  4. 228 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java
  5. 15 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java
  6. 22 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/LearnToRankFeatureExtractorBuilder.java
  7. 4 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java
  8. 203 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigTests.java
  9. 120 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java
  10. 1 1
      x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java
  11. 1 1
      x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java
  12. 9 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  13. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java
  14. 12 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
  15. 141 35
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
  16. 22 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/FeatureExtractor.java
  17. 64 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/FieldValueFeatureExtractor.java
  18. 8 11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java
  19. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java
  20. 11 16
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java
  21. 9 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java
  22. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
  23. 6 9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java
  24. 1 1
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_rescore.yml

+ 74 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlLTRNamedXContentProvider.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Only the LTR named writeables and xcontent. Remove and combine with inference provider
+ * when feature flag is removed
+ */
+public class MlLTRNamedXContentProvider implements NamedXContentProvider {
+
+    @Override
+    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        // Lenient Inference Config
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                LenientlyParsedInferenceConfig.class,
+                LearnToRankConfig.NAME,
+                LearnToRankConfig::fromXContentLenient
+            )
+        );
+        // Strict Inference Config
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                StrictlyParsedInferenceConfig.class,
+                LearnToRankConfig.NAME,
+                LearnToRankConfig::fromXContentStrict
+            )
+        );
+        // Inference Config Update
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                InferenceConfigUpdate.class,
+                LearnToRankConfigUpdate.NAME,
+                LearnToRankConfigUpdate::fromXContentStrict
+            )
+        );
+        return namedXContent;
+    }
+
+    public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
+        // Inference config
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(InferenceConfig.class, LearnToRankConfig.NAME.getPreferredName(), LearnToRankConfig::new)
+        );
+        // Inference config update
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                InferenceConfigUpdate.class,
+                LearnToRankConfigUpdate.NAME.getPreferredName(),
+                LearnToRankConfigUpdate::new
+            )
+        );
+        return namedWriteables;
+    }
+}

+ 12 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

@@ -40,4 +40,16 @@ public interface InferenceConfig extends NamedXContentObject, VersionedNamedWrit
     String getResultsField();
 
     boolean isAllocateOnly();
+
+    default boolean supportsIngestPipeline() {
+        return true;
+    }
+
+    default boolean supportsPipelineAggregation() {
+        return true;
+    }
+
+    default boolean supportsSearchRescorer() {
+        return false;
+    }
 }

+ 201 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java

@@ -0,0 +1,201 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class LearnToRankConfig extends RegressionConfig {
+
+    public static final ParseField NAME = new ParseField("learn_to_rank");
+    static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersion.current();
+    public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
+    public static final ParseField FEATURE_EXTRACTORS = new ParseField("feature_extractors");
+    public static LearnToRankConfig EMPTY_PARAMS = new LearnToRankConfig(null, null);
+
+    private static final ObjectParser<LearnToRankConfig.Builder, Boolean> LENIENT_PARSER = createParser(true);
+    private static final ObjectParser<LearnToRankConfig.Builder, Boolean> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<LearnToRankConfig.Builder, Boolean> createParser(boolean lenient) {
+        ObjectParser<LearnToRankConfig.Builder, Boolean> parser = new ObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            LearnToRankConfig.Builder::new
+        );
+        parser.declareInt(Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
+        parser.declareNamedObjects(
+            Builder::setLearnToRankFeatureExtractorBuilders,
+            (p, c, n) -> p.namedObject(LearnToRankFeatureExtractorBuilder.class, n, lenient),
+            b -> {},
+            FEATURE_EXTRACTORS
+        );
+        return parser;
+    }
+
+    public static LearnToRankConfig fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    public static LearnToRankConfig fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null).build();
+    }
+
+    private final List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilders;
+
+    public LearnToRankConfig(Integer numTopFeatureImportanceValues, List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilders) {
+        super(DEFAULT_RESULTS_FIELD, numTopFeatureImportanceValues);
+        if (featureExtractorBuilders != null) {
+            Set<String> featureNames = featureExtractorBuilders.stream()
+                .map(LearnToRankFeatureExtractorBuilder::featureName)
+                .collect(Collectors.toSet());
+            if (featureNames.size() < featureExtractorBuilders.size()) {
+                throw new IllegalArgumentException(
+                    "[" + FEATURE_EXTRACTORS.getPreferredName() + "] contains duplicate [feature_name] values"
+                );
+            }
+        }
+        this.featureExtractorBuilders = featureExtractorBuilders == null ? List.of() : featureExtractorBuilders;
+    }
+
+    public LearnToRankConfig(StreamInput in) throws IOException {
+        super(in);
+        this.featureExtractorBuilders = in.readNamedWriteableList(LearnToRankFeatureExtractorBuilder.class);
+    }
+
+    public List<LearnToRankFeatureExtractorBuilder> getFeatureExtractorBuilders() {
+        return featureExtractorBuilders;
+    }
+
+    @Override
+    public String getResultsField() {
+        return DEFAULT_RESULTS_FIELD;
+    }
+
+    @Override
+    public boolean isAllocateOnly() {
+        return false;
+    }
+
+    @Override
+    public boolean supportsIngestPipeline() {
+        return false;
+    }
+
+    @Override
+    public boolean supportsPipelineAggregation() {
+        return false;
+    }
+
+    @Override
+    public boolean supportsSearchRescorer() {
+        return true;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeNamedWriteableList(featureExtractorBuilders);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), getNumTopFeatureImportanceValues());
+        if (featureExtractorBuilders.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(
+                builder,
+                params,
+                true,
+                FEATURE_EXTRACTORS.getPreferredName(),
+                featureExtractorBuilders
+            );
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        LearnToRankConfig that = (LearnToRankConfig) o;
+        return Objects.equals(featureExtractorBuilders, that.featureExtractorBuilders);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), featureExtractorBuilders);
+    }
+
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        return TargetType.REGRESSION.equals(targetType);
+    }
+
+    @Override
+    public Version getMinimalSupportedNodeVersion() {
+        return Version.CURRENT;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedTransportVersion() {
+        return MIN_SUPPORTED_TRANSPORT_VERSION;
+    }
+
+    public static class Builder {
+        private Integer numTopFeatureImportanceValues;
+        private List<LearnToRankFeatureExtractorBuilder> learnToRankFeatureExtractorBuilders;
+
+        Builder() {}
+
+        Builder(LearnToRankConfig config) {
+            this.numTopFeatureImportanceValues = config.getNumTopFeatureImportanceValues();
+            this.learnToRankFeatureExtractorBuilders = config.featureExtractorBuilders;
+        }
+
+        public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
+            this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
+            return this;
+        }
+
+        public Builder setLearnToRankFeatureExtractorBuilders(
+            List<LearnToRankFeatureExtractorBuilder> learnToRankFeatureExtractorBuilders
+        ) {
+            this.learnToRankFeatureExtractorBuilders = learnToRankFeatureExtractorBuilders;
+            return this;
+        }
+
+        public LearnToRankConfig build() {
+            return new LearnToRankConfig(numTopFeatureImportanceValues, learnToRankFeatureExtractorBuilders);
+        }
+    }
+}

+ 228 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdate.java

@@ -0,0 +1,228 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.FEATURE_EXTRACTORS;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
+
+public class LearnToRankConfigUpdate implements InferenceConfigUpdate, NamedXContentObject {
+
+    public static final ParseField NAME = LearnToRankConfig.NAME;
+
+    public static LearnToRankConfigUpdate EMPTY_PARAMS = new LearnToRankConfigUpdate(null, null);
+
+    public static LearnToRankConfigUpdate fromConfig(LearnToRankConfig config) {
+        return new LearnToRankConfigUpdate(config.getNumTopFeatureImportanceValues(), config.getFeatureExtractorBuilders());
+    }
+
+    private static final ObjectParser<LearnToRankConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<LearnToRankConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<LearnToRankConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            LearnToRankConfigUpdate.Builder::new
+        );
+        parser.declareInt(LearnToRankConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
+        parser.declareNamedObjects(
+            LearnToRankConfigUpdate.Builder::setFeatureExtractorBuilders,
+            (p, c, n) -> p.namedObject(LearnToRankFeatureExtractorBuilder.class, n, false),
+            b -> {},
+            FEATURE_EXTRACTORS
+        );
+        return parser;
+    }
+
+    public static LearnToRankConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final Integer numTopFeatureImportanceValues;
+    private final List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilderList;
+
+    public LearnToRankConfigUpdate(
+        Integer numTopFeatureImportanceValues,
+        List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilders
+    ) {
+        if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
+            throw new IllegalArgumentException(
+                "[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + "] must be greater than or equal to 0"
+            );
+        }
+        if (featureExtractorBuilders != null) {
+            Set<String> featureNames = featureExtractorBuilders.stream()
+                .map(LearnToRankFeatureExtractorBuilder::featureName)
+                .collect(Collectors.toSet());
+            if (featureNames.size() < featureExtractorBuilders.size()) {
+                throw new IllegalArgumentException(
+                    "[" + FEATURE_EXTRACTORS.getPreferredName() + "] contains duplicate [feature_name] values"
+                );
+            }
+        }
+        this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
+        this.featureExtractorBuilderList = featureExtractorBuilders == null ? List.of() : featureExtractorBuilders;
+    }
+
+    public LearnToRankConfigUpdate(StreamInput in) throws IOException {
+        this.numTopFeatureImportanceValues = in.readOptionalVInt();
+        this.featureExtractorBuilderList = in.readNamedWriteableList(LearnToRankFeatureExtractorBuilder.class);
+    }
+
+    public Integer getNumTopFeatureImportanceValues() {
+        return numTopFeatureImportanceValues;
+    }
+
+    @Override
+    public String getResultsField() {
+        return DEFAULT_RESULTS_FIELD;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder().setNumTopFeatureImportanceValues(numTopFeatureImportanceValues);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalVInt(numTopFeatureImportanceValues);
+        out.writeNamedWriteableList(featureExtractorBuilderList);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return LearnToRankConfig.MIN_SUPPORTED_TRANSPORT_VERSION;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (numTopFeatureImportanceValues != null) {
+            builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
+        }
+        if (featureExtractorBuilderList.isEmpty() == false) {
+            NamedXContentObjectHelper.writeNamedObjects(
+                builder,
+                params,
+                true,
+                FEATURE_EXTRACTORS.getPreferredName(),
+                featureExtractorBuilderList
+            );
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        LearnToRankConfigUpdate that = (LearnToRankConfigUpdate) o;
+        return Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
+            && Objects.equals(this.featureExtractorBuilderList, that.featureExtractorBuilderList);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(numTopFeatureImportanceValues, featureExtractorBuilderList);
+    }
+
+    @Override
+    public LearnToRankConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof LearnToRankConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName()
+            );
+        }
+
+        LearnToRankConfig ltrConfig = (LearnToRankConfig) originalConfig;
+        if (isNoop(ltrConfig)) {
+            return ltrConfig;
+        }
+        LearnToRankConfig.Builder builder = new LearnToRankConfig.Builder(ltrConfig);
+        if (numTopFeatureImportanceValues != null) {
+            builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues);
+        }
+        if (featureExtractorBuilderList.isEmpty() == false) {
+            Map<String, LearnToRankFeatureExtractorBuilder> existingExtractors = ltrConfig.getFeatureExtractorBuilders()
+                .stream()
+                .collect(Collectors.toMap(LearnToRankFeatureExtractorBuilder::featureName, f -> f));
+            featureExtractorBuilderList.forEach(f -> existingExtractors.put(f.featureName(), f));
+            builder.setLearnToRankFeatureExtractorBuilders(new ArrayList<>(existingExtractors.values()));
+        }
+        return builder.build();
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig inferenceConfig) {
+        return inferenceConfig instanceof LearnToRankConfig;
+    }
+
+    boolean isNoop(LearnToRankConfig originalConfig) {
+        return (numTopFeatureImportanceValues == null || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues)
+            && (featureExtractorBuilderList.isEmpty()
+                || Objects.equals(originalConfig.getFeatureExtractorBuilders(), featureExtractorBuilderList));
+    }
+
+    public static class Builder implements InferenceConfigUpdate.Builder<Builder, LearnToRankConfigUpdate> {
+        private Integer numTopFeatureImportanceValues;
+        private List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilderList;
+
+        @Override
+        public Builder setResultsField(String resultsField) {
+            assert false : "results field should never be set in ltr config";
+            return this;
+        }
+
+        public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
+            this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
+            return this;
+        }
+
+        public Builder setFeatureExtractorBuilders(List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilderList) {
+            this.featureExtractorBuilderList = featureExtractorBuilderList;
+            return this;
+        }
+
+        @Override
+        public LearnToRankConfigUpdate build() {
+            return new LearnToRankConfigUpdate(numTopFeatureImportanceValues, featureExtractorBuilderList);
+        }
+    }
+}

+ 15 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java

@@ -29,4 +29,19 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse
      * @return the model tokenization parameters
      */
     Tokenization getTokenization();
+
+    @Override
+    default boolean supportsIngestPipeline() {
+        return true;
+    }
+
+    @Override
+    default boolean supportsPipelineAggregation() {
+        return false;
+    }
+
+    @Override
+    default boolean supportsSearchRescorer() {
+        return false;
+    }
 }

+ 22 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ltr/LearnToRankFeatureExtractorBuilder.java

@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+public interface LearnToRankFeatureExtractorBuilder extends NamedXContentObject, NamedWriteable {
+
+    ParseField FEATURE_NAME = new ParseField("feature_name");
+
+    /**
+     * @return The input feature that this extractor satisfies
+     */
+    String featureName();
+}

+ 4 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java

@@ -74,14 +74,16 @@ public abstract class InferenceConfigItemTestCase<T extends VersionedNamedWritea
     protected NamedXContentRegistry xContentRegistry() {
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
         namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
         return new NamedXContentRegistry(namedXContent);
     }
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
-        return new NamedWriteableRegistry(entries);
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
     }
 
     @Override

+ 203 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigTests.java

@@ -0,0 +1,203 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+
+public class LearnToRankConfigTests extends InferenceConfigItemTestCase<LearnToRankConfig> {
+    private boolean lenient;
+
+    public static LearnToRankConfig randomLearnToRankConfig() {
+        return new LearnToRankConfig(
+            randomBoolean() ? null : randomIntBetween(0, 10),
+            randomBoolean()
+                ? null
+                : Stream.generate(() -> new TestValueExtractor(randomAlphaOfLength(10))).limit(randomInt(5)).collect(Collectors.toList())
+        );
+    }
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected LearnToRankConfig createTestInstance() {
+        return randomLearnToRankConfig();
+    }
+
+    @Override
+    protected LearnToRankConfig mutateInstance(LearnToRankConfig instance) {
+        return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.isEmpty() == false;
+    }
+
+    @Override
+    protected Writeable.Reader<LearnToRankConfig> instanceReader() {
+        return LearnToRankConfig::new;
+    }
+
+    @Override
+    protected LearnToRankConfig doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? LearnToRankConfig.fromXContentLenient(parser) : LearnToRankConfig.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected LearnToRankConfig mutateInstanceForVersion(LearnToRankConfig instance, TransportVersion version) {
+        return instance;
+    }
+
+    public void testDuplicateFeatureNames() {
+        List<LearnToRankFeatureExtractorBuilder> featureExtractorBuilderList = List.of(
+            new TestValueExtractor("foo"),
+            new TestValueExtractor("foo")
+        );
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> new LearnToRankConfig(randomBoolean() ? null : randomIntBetween(0, 10), featureExtractorBuilderList)
+        );
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                TestValueExtractor.NAME,
+                TestValueExtractor::fromXContent
+            )
+        );
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                TestValueExtractor.NAME.getPreferredName(),
+                TestValueExtractor::new
+            )
+        );
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    static class TestValueExtractor implements LearnToRankFeatureExtractorBuilder {
+        public static final ParseField NAME = new ParseField("test");
+        private final String featureName;
+
+        private static final ConstructingObjectParser<TestValueExtractor, Void> PARSER = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            a -> new TestValueExtractor((String) a[0])
+        );
+        private static final ConstructingObjectParser<TestValueExtractor, Void> LENIENT_PARSER = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            true,
+            a -> new TestValueExtractor((String) a[0])
+        );
+        static {
+            PARSER.declareString(constructorArg(), FEATURE_NAME);
+            LENIENT_PARSER.declareString(constructorArg(), FEATURE_NAME);
+        }
+
+        public static TestValueExtractor fromXContent(XContentParser parser, Object context) {
+            boolean lenient = Boolean.TRUE.equals(context);
+            return lenient ? LENIENT_PARSER.apply(parser, null) : PARSER.apply(parser, null);
+        }
+
+        TestValueExtractor(StreamInput in) throws IOException {
+            this.featureName = in.readString();
+        }
+
+        TestValueExtractor(String featureName) {
+            this.featureName = featureName;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(FEATURE_NAME.getPreferredName(), featureName);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(featureName);
+        }
+
+        @Override
+        public String featureName() {
+            return featureName;
+        }
+
+        @Override
+        public String getName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            TestValueExtractor that = (TestValueExtractor) o;
+            return Objects.equals(featureName, that.featureName);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(featureName);
+        }
+    }
+}

+ 120 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfigUpdateTests.java

@@ -0,0 +1,120 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.in;
+import static org.hamcrest.Matchers.is;
+
+public class LearnToRankConfigUpdateTests extends AbstractBWCSerializationTestCase<LearnToRankConfigUpdate> {
+
+    public static LearnToRankConfigUpdate randomLearnToRankConfigUpdate() {
+        return new LearnToRankConfigUpdate(randomBoolean() ? null : randomIntBetween(0, 10), null);
+    }
+
+    public void testApply() {
+        LearnToRankConfig originalConfig = randomLearnToRankConfig();
+        assertThat(originalConfig, equalTo(LearnToRankConfigUpdate.EMPTY_PARAMS.apply(originalConfig)));
+        assertThat(
+            new LearnToRankConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(),
+            equalTo(new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig))
+        );
+        assertThat(
+            new LearnToRankConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(1).build(),
+            equalTo(new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1).build().apply(originalConfig))
+        );
+
+        LearnToRankFeatureExtractorBuilder extractorBuilder = new LearnToRankConfigTests.TestValueExtractor("foo");
+        LearnToRankFeatureExtractorBuilder extractorBuilder2 = new LearnToRankConfigTests.TestValueExtractor("bar");
+
+        LearnToRankConfig config = new LearnToRankConfigUpdate.Builder().setNumTopFeatureImportanceValues(1)
+            .setFeatureExtractorBuilders(List.of(extractorBuilder2, extractorBuilder))
+            .build()
+            .apply(originalConfig);
+        assertThat(config.getNumTopFeatureImportanceValues(), equalTo(1));
+        assertThat(extractorBuilder2, is(in(config.getFeatureExtractorBuilders())));
+        assertThat(extractorBuilder, is(in(config.getFeatureExtractorBuilders())));
+    }
+
+    @Override
+    protected LearnToRankConfigUpdate createTestInstance() {
+        return randomLearnToRankConfigUpdate();
+    }
+
+    @Override
+    protected LearnToRankConfigUpdate mutateInstance(LearnToRankConfigUpdate instance) {
+        return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
+    }
+
+    @Override
+    protected Writeable.Reader<LearnToRankConfigUpdate> instanceReader() {
+        return LearnToRankConfigUpdate::new;
+    }
+
+    @Override
+    protected LearnToRankConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return LearnToRankConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected LearnToRankConfigUpdate mutateInstanceForVersion(LearnToRankConfigUpdate instance, TransportVersion version) {
+        return instance;
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                LearnToRankConfigTests.TestValueExtractor.NAME,
+                LearnToRankConfigTests.TestValueExtractor::fromXContent
+            )
+        );
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry writableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                LearnToRankFeatureExtractorBuilder.class,
+                LearnToRankConfigTests.TestValueExtractor.NAME.getPreferredName(),
+                LearnToRankConfigTests.TestValueExtractor::new
+            )
+        );
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return writableRegistry();
+    }
+}

+ 1 - 1
x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java

@@ -33,7 +33,7 @@ public class MlRescorerIT extends ESRestTestCase {
                         "description": "super complex model for tests",
                         "input": {"field_names": ["cost", "product"]},
                         "inference_config": {
-                          "regression": {
+                          "learn_to_rank": {
                           }
                         },
                         "definition": {

+ 1 - 1
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceRescorerIT.java

@@ -30,7 +30,7 @@ public class InferenceRescorerIT extends InferenceTestCase {
                         "description": "super complex model for tests",
                         "input": {"field_names": ["cost", "product"]},
                         "inference_config": {
-                          "regression": {
+                          "learn_to_rank": {
                           }
                         },
                         "definition": {

+ 9 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -190,6 +190,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNam
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@@ -1755,6 +1756,10 @@ public class MachineLearning extends Plugin
             )
         );
         namedXContent.addAll(new CorrelationNamedContentProvider().getNamedXContentParsers());
+        // LTR Combine with Inference named content provider when feature flag is removed
+        if (InferenceRescorerFeature.isEnabled()) {
+            namedXContent.addAll(new MlLTRNamedXContentProvider().getNamedXContentParsers());
+        }
         return namedXContent;
     }
 
@@ -1839,7 +1844,10 @@ public class MachineLearning extends Plugin
         namedWriteables.addAll(MlAutoscalingNamedWritableProvider.getNamedWriteables());
         namedWriteables.addAll(new CorrelationNamedContentProvider().getNamedWriteables());
         namedWriteables.addAll(new ChangePointNamedContentProvider().getNamedWriteables());
-
+        // LTR Combine with Inference named content provider when feature flag is removed
+        if (InferenceRescorerFeature.isEnabled()) {
+            namedWriteables.addAll(new MlLTRNamedXContentProvider().getNamedWriteables());
+        }
         return namedWriteables;
     }
 

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java

@@ -264,7 +264,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
 
         SetOnce<LocalModel> loadedModel = new SetOnce<>();
         BiConsumer<Client, ActionListener<?>> modelLoadAction = (client, listener) -> modelLoadingService.get()
-            .getModelForSearch(modelId, listener.delegateFailure((delegate, localModel) -> {
+            .getModelForAggregation(modelId, listener.delegateFailure((delegate, localModel) -> {
                 loadedModel.set(localModel);
 
                 boolean isLicensed = localModel.getLicenseLevel() == License.OperationMode.BASIC

+ 12 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.license.License;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -59,6 +60,7 @@ public class LocalModel implements Closeable {
     private final CircuitBreaker trainedModelCircuitBreaker;
     private final AtomicLong referenceCount;
     private final long cachedRamBytesUsed;
+    private final TrainedModelType trainedModelType;
 
     LocalModel(
         String modelId,
@@ -68,6 +70,7 @@ public class LocalModel implements Closeable {
         Map<String, String> defaultFieldMap,
         InferenceConfig modelInferenceConfig,
         License.OperationMode licenseLevel,
+        TrainedModelType trainedModelType,
         TrainedModelStatsService trainedModelStatsService,
         CircuitBreaker trainedModelCircuitBreaker
     ) {
@@ -85,6 +88,7 @@ public class LocalModel implements Closeable {
         this.licenseLevel = licenseLevel;
         this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
         this.referenceCount = new AtomicLong(1);
+        this.trainedModelType = trainedModelType;
     }
 
     long ramBytesUsed() {
@@ -94,6 +98,14 @@ public class LocalModel implements Closeable {
         return cachedRamBytesUsed;
     }
 
+    public InferenceConfig getInferenceConfig() {
+        return inferenceConfig;
+    }
+
+    TrainedModelType getTrainedModelType() {
+        return trainedModelType;
+    }
+
     public String getModelId() {
         return modelId;
     }

+ 141 - 35
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

@@ -26,6 +26,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.ingest.IngestMetadata;
@@ -36,6 +37,7 @@ import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -51,12 +53,14 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.ArrayDeque;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -110,11 +114,71 @@ public class ModelLoadingService implements ClusterStateListener {
         Setting.Property.NodeScope
     );
 
-    // The feature requesting the model
+    /**
+     * The cached model consumer. Various consumers dictate the model's usage and context
+     */
     public enum Consumer {
-        PIPELINE,
-        SEARCH,
-        INTERNAL
+        PIPELINE() {
+            @Override
+            public boolean inferenceConfigSupported(InferenceConfig config) {
+                return config == null || config.supportsIngestPipeline();
+            }
+
+            @Override
+            public String exceptionName() {
+                return "ingest";
+            }
+        },
+        SEARCH_AGGS() {
+            @Override
+            public boolean inferenceConfigSupported(InferenceConfig config) {
+                return config == null || config.supportsPipelineAggregation();
+            }
+
+            @Override
+            public String exceptionName() {
+                return "search(aggregation)";
+            }
+        },
+        SEARCH_RESCORER() {
+            @Override
+            public boolean inferenceConfigSupported(InferenceConfig config) {
+                // Null configs imply creation via target type. This is for BWC for very old models
+                // Consequently, if the config is null, we don't support LTR with them.
+                return config != null && config.supportsSearchRescorer();
+            }
+
+            @Override
+            public String exceptionName() {
+                return "search(rescorer)";
+            }
+        },
+        INTERNAL() {
+            @Override
+            public boolean inferenceConfigSupported(InferenceConfig config) {
+                return true;
+            }
+
+            @Override
+            public String exceptionName() {
+                return "internal";
+            }
+        };
+
+        /**
+         * @param config The inference config for the model. It may be null for very old Regression or classification models
+         * @return Is this configuration type supported within this cache context?
+         */
+        public abstract boolean inferenceConfigSupported(@Nullable InferenceConfig config);
+
+        /**
+         * @return The cache context name to use if an exception must be thrown due to the config not being supported
+         */
+        public abstract String exceptionName();
+
+        public boolean isAnyOf(Consumer... consumers) {
+            return Arrays.stream(consumers).anyMatch(c -> this == c);
+        }
     }
 
     private static class ModelAndConsumer {
@@ -219,13 +283,23 @@ public class ModelLoadingService implements ClusterStateListener {
     }
 
     /**
-     * Load the model for use by at search. Models requested by search are always cached.
+     * Load the model for use by at search through aggregations. Models requested by search are always cached.
+     *
+     * @param modelId  the model to get
+     * @param modelActionListener the listener to alert when the model has been retrieved
+     */
+    public void getModelForAggregation(String modelId, ActionListener<LocalModel> modelActionListener) {
+        getModel(modelId, Consumer.SEARCH_AGGS, null, modelActionListener);
+    }
+
+    /**
+     * Load the model for use by at search for rescoring. Models requested by search are always cached.
      *
      * @param modelId  the model to get
      * @param modelActionListener the listener to alert when the model has been retrieved
      */
-    public void getModelForSearch(String modelId, ActionListener<LocalModel> modelActionListener) {
-        getModel(modelId, Consumer.SEARCH, null, modelActionListener);
+    public void getModelForLearnToRank(String modelId, ActionListener<LocalModel> modelActionListener) {
+        getModel(modelId, Consumer.SEARCH_RESCORER, null, modelActionListener);
     }
 
     /**
@@ -259,6 +333,18 @@ public class ModelLoadingService implements ClusterStateListener {
         final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
         ModelAndConsumer cachedModel = localModelCache.get(modelId);
         if (cachedModel != null) {
+            // Even if the model is already cached, we don't want to use the model in an unsupported task
+            if (consumer.inferenceConfigSupported(cachedModel.model.getInferenceConfig()) == false) {
+                modelActionListener.onFailure(
+                    modelUnsupportedInUsageContext(
+                        modelId,
+                        cachedModel.model.getTrainedModelType(),
+                        cachedModel.model.getInferenceConfig(),
+                        consumer
+                    )
+                );
+                return;
+            }
             cachedModel.consumers.add(consumer);
             try {
                 cachedModel.model.acquire();
@@ -314,7 +400,6 @@ public class ModelLoadingService implements ClusterStateListener {
                     localModelToNotifyListener.set(cachedModel.model);
                     return true;
                 }
-
                 // Add the listener to the queue if the model is loading
                 Queue<ActionListener<LocalModel>> listeners = loadingListeners.computeIfPresent(
                     modelId,
@@ -330,7 +415,8 @@ public class ModelLoadingService implements ClusterStateListener {
 
                 // The model is not currently being loaded (indicated by listeners check above).
                 // So start a new load outside of the synchronized block.
-                if (Consumer.SEARCH != consumer && referencedModels.contains(modelId) == false) {
+                if (consumer.isAnyOf(Consumer.SEARCH_AGGS, Consumer.SEARCH_RESCORER) == false
+                    && referencedModels.contains(modelId) == false) {
                     // The model is requested by a pipeline but not referenced by any ingest pipelines.
                     // This means it is a simulate call and the model should not be cached
                     logger.trace(
@@ -368,19 +454,19 @@ public class ModelLoadingService implements ClusterStateListener {
         // We don't want to cancel the loading if only ONE of them stops listening or closes connection
         // TODO Is there a way to only signal a cancel if all the listener tasks cancel???
         provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, ActionListener.wrap(trainedModelConfig -> {
-            if (trainedModelConfig.isAllocateOnly()) {
-                if (consumer == Consumer.SEARCH) {
-                    handleLoadFailure(
+            if (consumer.inferenceConfigSupported(trainedModelConfig.getInferenceConfig()) == false) {
+                handleLoadFailure(
+                    modelId,
+                    modelUnsupportedInUsageContext(
                         modelId,
-                        new ElasticsearchStatusException(
-                            "Trained model [{}] with type [{}] is currently not usable in search.",
-                            RestStatus.BAD_REQUEST,
-                            modelId,
-                            trainedModelConfig.getModelType()
-                        )
-                    );
-                    return;
-                }
+                        trainedModelConfig.getModelType(),
+                        trainedModelConfig.getInferenceConfig(),
+                        consumer
+                    )
+                );
+                return;
+            }
+            if (trainedModelConfig.isAllocateOnly()) {
                 handleLoadFailure(modelId, modelMustBeDeployedError(modelId));
                 return;
             }
@@ -419,19 +505,21 @@ public class ModelLoadingService implements ClusterStateListener {
         // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
         // by a simulated pipeline
         provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), parentTaskId, ActionListener.wrap(trainedModelConfig -> {
+            // If the model is used in an unsupported context, fail here
+            if (consumer.inferenceConfigSupported(trainedModelConfig.getInferenceConfig()) == false) {
+                handleLoadFailure(
+                    modelId,
+                    modelUnsupportedInUsageContext(
+                        modelId,
+                        trainedModelConfig.getModelType(),
+                        trainedModelConfig.getInferenceConfig(),
+                        consumer
+                    )
+                );
+                return;
+            }
             // If the model should be allocated, we should fail here
             if (trainedModelConfig.isAllocateOnly()) {
-                if (consumer == Consumer.SEARCH) {
-                    modelActionListener.onFailure(
-                        new ElasticsearchStatusException(
-                            "model [{}] with type [{}] is currently not usable in search.",
-                            RestStatus.BAD_REQUEST,
-                            modelId,
-                            trainedModelConfig.getModelType()
-                        )
-                    );
-                    return;
-                }
                 modelActionListener.onFailure(modelMustBeDeployedError(modelId));
                 return;
             }
@@ -457,6 +545,7 @@ public class ModelLoadingService implements ClusterStateListener {
                         trainedModelConfig.getDefaultFieldMap(),
                         inferenceConfig,
                         trainedModelConfig.getLicenseLevel(),
+                        trainedModelConfig.getModelType(),
                         modelStatsService,
                         trainedModelCircuitBreaker
                     )
@@ -500,7 +589,7 @@ public class ModelLoadingService implements ClusterStateListener {
         }
     }
 
-    private ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
+    private static ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
         return new ElasticsearchStatusException(
             "Model [{}] must be deployed to use. Please deploy with the start trained model deployment API.",
             RestStatus.BAD_REQUEST,
@@ -508,6 +597,22 @@ public class ModelLoadingService implements ClusterStateListener {
         );
     }
 
+    private static ElasticsearchStatusException modelUnsupportedInUsageContext(
+        String modelId,
+        TrainedModelType modelType,
+        InferenceConfig inferenceConfig,
+        Consumer consumer
+    ) {
+        return new ElasticsearchStatusException(
+            "Trained model [{}] with type [{}] and task [{}] is currently not usable in [{}].",
+            RestStatus.BAD_REQUEST,
+            modelId,
+            modelType,
+            Optional.ofNullable(inferenceConfig).map(InferenceConfig::getName).orElse("_unknown_"),
+            consumer.exceptionName()
+        );
+    }
+
     private void handleLoadSuccess(
         String modelId,
         Consumer consumer,
@@ -526,6 +631,7 @@ public class ModelLoadingService implements ClusterStateListener {
             trainedModelConfig.getDefaultFieldMap(),
             inferenceConfig,
             trainedModelConfig.getLicenseLevel(),
+            Optional.ofNullable(trainedModelConfig.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE),
             modelStatsService,
             trainedModelCircuitBreaker
         );
@@ -536,7 +642,7 @@ public class ModelLoadingService implements ClusterStateListener {
             // Also, if the consumer is a search consumer, we should always cache it
             if (referencedModels.contains(modelId)
                 || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels)
-                || consumer.equals(Consumer.SEARCH)) {
+                || consumer.equals(Consumer.SEARCH_AGGS)) {
                 try {
                     // The local model may already be in cache. If it is, we don't bother adding it to cache.
                     // If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted
@@ -699,7 +805,7 @@ public class ModelLoadingService implements ClusterStateListener {
                 );
                 if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) {
                     ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
-                    if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
+                    if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH_AGGS) == false) {
                         logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId);
                         localModelCache.invalidate(modelId);
                     }

+ 22 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/FeatureExtractor.java

@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.rescorer;
+
+import org.apache.lucene.index.LeafReaderContext;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+public interface FeatureExtractor {
+    void setNextReader(LeafReaderContext segmentContext) throws IOException;
+
+    void addFeatures(Map<String, Object> featureMap, int docId) throws IOException;
+
+    List<String> featureNames();
+}

+ 64 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/FieldValueFeatureExtractor.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.rescorer;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.lookup.SearchLookup;
+import org.elasticsearch.search.lookup.Source;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+public class FieldValueFeatureExtractor implements FeatureExtractor {
+
+    record FieldValueFetcher(String fieldName, ValueFetcher valueFetcher) {}
+
+    private LeafReaderContext segmentContext;
+    private final List<String> documentFieldNames;
+    private final List<FieldValueFetcher> valueFetcherList;
+    private final SearchLookup sourceLookup;
+
+    FieldValueFeatureExtractor(List<String> documentFieldNames, SearchExecutionContext executionContext) {
+        this.documentFieldNames = documentFieldNames;
+        this.valueFetcherList = documentFieldNames.stream().map(s -> {
+            MappedFieldType mappedFieldType = executionContext.getFieldType(s);
+            if (mappedFieldType != null) {
+                return new FieldValueFetcher(s, mappedFieldType.valueFetcher(executionContext, null));
+            }
+            return null;
+        }).filter(Objects::nonNull).toList();
+        this.sourceLookup = executionContext.lookup();
+    }
+
+    @Override
+    public void setNextReader(LeafReaderContext segmentContext) {
+        this.segmentContext = segmentContext;
+        for (FieldValueFetcher vf : valueFetcherList) {
+            vf.valueFetcher().setNextReader(segmentContext);
+        }
+    }
+
+    @Override
+    public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
+        Source source = sourceLookup.getSource(this.segmentContext, docId);
+        for (FieldValueFetcher vf : this.valueFetcherList) {
+            featureMap.put(vf.fieldName(), vf.valueFetcher().fetchValues(source, docId, new ArrayList<>()).get(0));
+        }
+    }
+
+    @Override
+    public List<String> featureNames() {
+        return documentFieldNames;
+    }
+}

+ 8 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java

@@ -15,11 +15,9 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.common.util.Maps;
-import org.elasticsearch.search.lookup.SearchLookup;
-import org.elasticsearch.search.lookup.Source;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.rescore.Rescorer;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 
 import java.io.IOException;
@@ -64,7 +62,6 @@ public class InferenceRescorer implements Rescorer {
         rescoreContext.setRescoredDocs(topNDocIDs);
         ScoreDoc[] hitsToRescore = topNFirstPass.scoreDocs;
         Arrays.sort(hitsToRescore, Comparator.comparingInt(a -> a.doc));
-        SearchLookup sourceLookup = ltrRescoreContext.executionContext.lookup();
         int hitUpto = 0;
         int readerUpto = -1;
         int endDoc = 0;
@@ -72,8 +69,9 @@ public class InferenceRescorer implements Rescorer {
         List<LeafReaderContext> leaves = ltrRescoreContext.executionContext.searcher().getIndexReader().leaves();
         LeafReaderContext currentSegment = null;
         boolean changedSegment = true;
+        List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors();
         List<Map<String, Object>> docFeatures = new ArrayList<>(topNDocIDs.size());
-        int featureSize = ltrRescoreContext.valueFetcherList.size();
+        int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
         while (hitUpto < hitsToRescore.length) {
             final ScoreDoc hit = hitsToRescore[hitUpto];
             final int docID = hit.doc;
@@ -87,16 +85,15 @@ public class InferenceRescorer implements Rescorer {
             if (changedSegment) {
                 // We advanced to another segment and update our document value fetchers
                 docBase = currentSegment.docBase;
-                for (InferenceRescorerContext.FieldValueFetcher vf : ltrRescoreContext.valueFetcherList) {
-                    vf.valueFetcher().setNextReader(currentSegment);
+                for (FeatureExtractor featureExtractor : featureExtractors) {
+                    featureExtractor.setNextReader(currentSegment);
                 }
                 changedSegment = false;
             }
             int targetDoc = docID - docBase;
             Map<String, Object> features = Maps.newMapWithExpectedSize(featureSize);
-            Source source = sourceLookup.getSource(currentSegment, targetDoc);
-            for (InferenceRescorerContext.FieldValueFetcher vf : ltrRescoreContext.valueFetcherList) {
-                features.put(vf.fieldName(), vf.valueFetcher().fetchValues(source, targetDoc, new ArrayList<>()).get(0));
+            for (FeatureExtractor featureExtractor : featureExtractors) {
+                featureExtractor.addFeatures(features, targetDoc);
             }
             docFeatures.add(features);
             hitUpto++;
@@ -104,7 +101,7 @@ public class InferenceRescorer implements Rescorer {
         for (int i = 0; i < hitsToRescore.length; i++) {
             Map<String, Object> features = docFeatures.get(i);
             try {
-                hitsToRescore[i].score = ((Number) definition.infer(features, RegressionConfigUpdate.EMPTY_PARAMS).predictedValue())
+                hitsToRescore[i].score = ((Number) definition.infer(features, LearnToRankConfigUpdate.EMPTY_PARAMS).predictedValue())
                     .floatValue();
             } catch (Exception ex) {
                 logger.warn("Failure rescoring doc...", ex);

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilder.java

@@ -114,7 +114,7 @@ public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerB
                 throw new IllegalStateException("Model loading service must be available");
             }
             SetOnce<LocalModel> inferenceDefinitionSetOnce = new SetOnce<>();
-            ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForSearch(modelId, ActionListener.wrap(lm -> {
+            ctx.registerAsyncAction((c, l) -> modelLoadingServiceSupplier.get().getModelForLearnToRank(modelId, ActionListener.wrap(lm -> {
                 inferenceDefinitionSetOnce.set(lm);
                 l.onResponse(null);
             }, l::onFailure)));

+ 11 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerContext.java

@@ -7,23 +7,18 @@
 
 package org.elasticsearch.xpack.ml.inference.rescorer;
 
-import org.elasticsearch.index.mapper.MappedFieldType;
-import org.elasticsearch.index.mapper.ValueFetcher;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.rescore.Rescorer;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 
+import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 public class InferenceRescorerContext extends RescoreContext {
 
-    record FieldValueFetcher(String fieldName, ValueFetcher valueFetcher) {}
-
     final SearchExecutionContext executionContext;
     final LocalModel inferenceDefinition;
-    final List<FieldValueFetcher> valueFetcherList;
 
     /**
      * @param windowSize how many documents to rescore
@@ -40,16 +35,16 @@ public class InferenceRescorerContext extends RescoreContext {
         super(windowSize, rescorer);
         this.executionContext = executionContext;
         this.inferenceDefinition = inferenceDefinition;
-        if (inferenceDefinition != null) {
-            this.valueFetcherList = inferenceDefinition.inputFields().stream().map(s -> {
-                MappedFieldType mappedFieldType = executionContext.getFieldType(s);
-                if (mappedFieldType != null) {
-                    return new InferenceRescorerContext.FieldValueFetcher(s, mappedFieldType.valueFetcher(executionContext, null));
-                }
-                return null;
-            }).filter(Objects::nonNull).toList();
-        } else {
-            valueFetcherList = List.of();
+    }
+
+    List<FeatureExtractor> buildFeatureExtractors() {
+        assert this.inferenceDefinition != null;
+        List<FeatureExtractor> featureExtractors = new ArrayList<>();
+        if (this.inferenceDefinition.inputFields().isEmpty() == false) {
+            featureExtractors.add(
+                new FieldValueFeatureExtractor(new ArrayList<>(this.inferenceDefinition.inputFields()), this.executionContext)
+            );
         }
+        return featureExtractors;
     }
 }

+ 9 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.ingest.TestIngestDocument;
 import org.elasticsearch.license.License;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -81,6 +82,7 @@ public class LocalModelTests extends ESTestCase {
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
             randomFrom(License.OperationMode.values()),
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -119,6 +121,7 @@ public class LocalModelTests extends ESTestCase {
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
             License.OperationMode.PLATINUM,
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -171,6 +174,7 @@ public class LocalModelTests extends ESTestCase {
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
             License.OperationMode.PLATINUM,
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -233,6 +237,7 @@ public class LocalModelTests extends ESTestCase {
             Collections.singletonMap("bar", "bar.keyword"),
             RegressionConfig.EMPTY_PARAMS,
             License.OperationMode.PLATINUM,
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -265,6 +270,7 @@ public class LocalModelTests extends ESTestCase {
             null,
             RegressionConfig.EMPTY_PARAMS,
             License.OperationMode.PLATINUM,
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -300,6 +306,7 @@ public class LocalModelTests extends ESTestCase {
             null,
             ClassificationConfig.EMPTY_PARAMS,
             License.OperationMode.PLATINUM,
+            TrainedModelType.TREE_ENSEMBLE,
             modelStatsService,
             mock(CircuitBreaker.class)
         );
@@ -359,6 +366,7 @@ public class LocalModelTests extends ESTestCase {
                 null,
                 ClassificationConfig.EMPTY_PARAMS,
                 License.OperationMode.PLATINUM,
+                TrainedModelType.TREE_ENSEMBLE,
                 modelStatsService,
                 breaker
             );
@@ -385,6 +393,7 @@ public class LocalModelTests extends ESTestCase {
                 null,
                 ClassificationConfig.EMPTY_PARAMS,
                 License.OperationMode.PLATINUM,
+                TrainedModelType.TREE_ENSEMBLE,
                 modelStatsService,
                 breaker
             );

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java

@@ -414,7 +414,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
 
         for (int i = 0; i < 3; i++) {
             PlainActionFuture<LocalModel> future = new PlainActionFuture<>();
-            modelLoadingService.getModelForSearch(modelId, future);
+            modelLoadingService.getModelForAggregation(modelId, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 

+ 6 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorerBuilderRewriteTests.java

@@ -20,7 +20,6 @@ import org.elasticsearch.index.query.CoordinatorRewriteContext;
 import org.elasticsearch.index.query.DataRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.license.XPackLicenseState;
-import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.test.AbstractBuilderTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
@@ -31,7 +30,6 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.io.IOException;
 import java.util.List;
-import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
@@ -102,14 +100,13 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas
         when(localModel.inputFields()).thenReturn(inputFields);
         SearchExecutionContext context = createSearchExecutionContext();
         InferenceRescorerBuilder inferenceRescorerBuilder = new InferenceRescorerBuilder("test_model", localModel);
-        RescoreContext rescoreContext = inferenceRescorerBuilder.innerBuildContext(20, context);
-        assertTrue(rescoreContext instanceof InferenceRescorerContext);
+        InferenceRescorerContext rescoreContext = inferenceRescorerBuilder.innerBuildContext(20, context);
+        assertNotNull(rescoreContext);
         assertThat(rescoreContext.getWindowSize(), equalTo(20));
-        assertThat(((InferenceRescorerContext) rescoreContext).valueFetcherList, hasSize(2));
+        List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors();
+        assertThat(featureExtractors, hasSize(1));
         assertThat(
-            ((InferenceRescorerContext) rescoreContext).valueFetcherList.stream()
-                .map(InferenceRescorerContext.FieldValueFetcher::fieldName)
-                .collect(Collectors.toList()),
+            featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(),
             containsInAnyOrder(DOUBLE_FIELD_NAME, INT_FIELD_NAME)
         );
     }
@@ -134,7 +131,7 @@ public class InferenceRescorerBuilderRewriteTests extends AbstractBuilderTestCas
         }
 
         @Override
-        public void getModelForSearch(String modelId, ActionListener<LocalModel> modelActionListener) {
+        public void getModelForLearnToRank(String modelId, ActionListener<LocalModel> modelActionListener) {
             modelActionListener.onResponse(localModel());
         }
     }

+ 1 - 1
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_rescore.yml

@@ -11,7 +11,7 @@ setup:
             "description": "super complex model for tests",
             "input": {"field_names": ["cost", "product"]},
             "inference_config": {
-              "regression": {
+              "learn_to_rank": {
               }
             },
             "definition": {