Browse Source

Add Bounded Window to Inference Models for Rescoring to Ensure Positive Score Range (#125694)

* apply bounded window inference model

* linting

* add unit tests

* [CI] Auto commit changes from spotless

* add additional tests

* remove unused constructor

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Mark J. Hoy 6 months ago
parent
commit
e77bf808ab

+ 5 - 0
docs/changelog/125694.yaml

@@ -0,0 +1,5 @@
+pr: 125694
+summary: LTR score bounding
+area: Ranking
+type: bug
+issues: []

+ 14 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java

@@ -0,0 +1,14 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
+
+public interface BoundedInferenceModel extends InferenceModel {
+    double getMinPredictedValue();
+
+    double getMaxPredictedValue();
+}

+ 123 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModel.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
+
+import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+
+import java.util.Map;
+
+public class BoundedWindowInferenceModel implements BoundedInferenceModel {
+    public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;
+
+    private final BoundedInferenceModel model;
+    private final double minPredictedValue;
+    private final double maxPredictedValue;
+    private final double adjustmentValue;
+
+    public BoundedWindowInferenceModel(BoundedInferenceModel model) {
+        this.model = model;
+        this.minPredictedValue = model.getMinPredictedValue();
+        this.maxPredictedValue = model.getMaxPredictedValue();
+
+        if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
+            this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
+        } else {
+            this.adjustmentValue = 0.0;
+        }
+    }
+
+    @Override
+    public String[] getFeatureNames() {
+        return model.getFeatureNames();
+    }
+
+    @Override
+    public TargetType targetType() {
+        return model.targetType();
+    }
+
+    @Override
+    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
+        return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
+    }
+
+    @Override
+    public InferenceResults infer(double[] features, InferenceConfig config) {
+        return boundInferenceResultScores(model.infer(features, config));
+    }
+
+    @Override
+    public boolean supportsFeatureImportance() {
+        return model.supportsFeatureImportance();
+    }
+
+    @Override
+    public String getName() {
+        return "bounded_window[" + model.getName() + "]";
+    }
+
+    @Override
+    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
+        model.rewriteFeatureIndices(newFeatureIndexMapping);
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        return model.ramBytesUsed();
+    }
+
+    @Override
+    public double getMinPredictedValue() {
+        return minPredictedValue;
+    }
+
+    @Override
+    public double getMaxPredictedValue() {
+        return maxPredictedValue;
+    }
+
+    private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
+        // if the min value < the default minimum, slide the values up by the adjustment value
+        if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
+            double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
+
+            predictedValue += this.adjustmentValue;
+
+            return new RegressionInferenceResults(
+                predictedValue,
+                inferenceResult.getResultsField(),
+                ((RegressionInferenceResults) inferenceResult).getFeatureImportance()
+            );
+        }
+
+        throw new IllegalStateException(
+            LoggerMessageFormat.format(
+                "Model used within a {} should return a {} but got {} instead",
+                BoundedWindowInferenceModel.class.getSimpleName(),
+                RegressionInferenceResults.class.getSimpleName(),
+                inferenceResult.getClass().getSimpleName()
+            )
+        );
+    }
+
+    @Override
+    public String toString() {
+        return "BoundedWindowInferenceModel{"
+            + "model="
+            + model
+            + ", minPredictedValue="
+            + getMinPredictedValue()
+            + ", maxPredictedValue="
+            + getMaxPredictedValue()
+            + '}';
+    }
+}

+ 57 - 17
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.util.CachedSupplier;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -52,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
 
-public class EnsembleInferenceModel implements InferenceModel {
+public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
 
     public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
     private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public class EnsembleInferenceModel implements InferenceModel {
     private final List<String> classificationLabels;
     private final double[] classificationWeights;
     private volatile boolean preparedForInference = false;
+    private final Supplier<double[]> predictedValuesBoundariesSupplier;
 
     private EnsembleInferenceModel(
         List<InferenceModel> models,
@@ -112,6 +115,7 @@ public class EnsembleInferenceModel implements InferenceModel {
         this.classificationWeights = classificationWeights == null
             ? null
             : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
+        this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
     }
 
     @Override
@@ -328,21 +332,57 @@ public class EnsembleInferenceModel implements InferenceModel {
 
     @Override
     public String toString() {
-        return "EnsembleInferenceModel{"
-            + "featureNames="
-            + Arrays.toString(featureNames)
-            + ", models="
-            + models
-            + ", outputAggregator="
-            + outputAggregator
-            + ", targetType="
-            + targetType
-            + ", classificationLabels="
-            + classificationLabels
-            + ", classificationWeights="
-            + Arrays.toString(classificationWeights)
-            + ", preparedForInference="
-            + preparedForInference
-            + '}';
+        StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
+
+        builder.append("featureNames=")
+            .append(Arrays.toString(featureNames))
+            .append(", models=")
+            .append(models)
+            .append(", outputAggregator=")
+            .append(outputAggregator)
+            .append(", targetType=")
+            .append(targetType);
+
+        if (targetType == TargetType.CLASSIFICATION) {
+            builder.append(", classificationLabels=")
+                .append(classificationLabels)
+                .append(", classificationWeights=")
+                .append(Arrays.toString(classificationWeights));
+        } else if (targetType == TargetType.REGRESSION) {
+            builder.append(", minPredictedValue=")
+                .append(getMinPredictedValue())
+                .append(", maxPredictedValue=")
+                .append(getMaxPredictedValue());
+        }
+
+        builder.append(", preparedForInference=").append(preparedForInference);
+
+        return builder.append('}').toString();
+    }
+
+    @Override
+    public double getMinPredictedValue() {
+        return this.predictedValuesBoundariesSupplier.get()[0];
+    }
+
+    @Override
+    public double getMaxPredictedValue() {
+        return this.predictedValuesBoundariesSupplier.get()[1];
+    }
+
+    private double[] initModelBoundaries() {
+        double[] modelsMinBoundaries = new double[models.size()];
+        double[] modelsMaxBoundaries = new double[models.size()];
+        int i = 0;
+        for (InferenceModel model : models) {
+            if (model instanceof BoundedInferenceModel boundedInferenceModel) {
+                modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
+                modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
+            } else {
+                throw new IllegalStateException("All submodels have to be bounded");
+            }
+        }
+
+        return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
     }
 }

+ 10 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -79,13 +80,21 @@ public class InferenceDefinition {
 
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
         preProcess(fields);
+
+        InferenceModel inferenceModel = trainedModel;
+
+        if (config instanceof LearningToRankConfig) {
+            assert trainedModel instanceof BoundedInferenceModel;
+            inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel);
+        }
+
         if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
             throw ExceptionsHelper.badRequestException(
                 "Feature importance is not supported for the configured model of type [{}]",
                 trainedModel.getName()
             );
         }
-        return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
+        return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
     }
 
     public TargetType getTargetType() {

+ 54 - 30
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java

@@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
 
-public class TreeInferenceModel implements InferenceModel {
+public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel {
 
     private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
     public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
@@ -90,7 +90,7 @@ public class TreeInferenceModel implements InferenceModel {
     private String[] featureNames;
     private final TargetType targetType;
     private List<String> classificationLabels;
-    private final double highOrderCategory;
+    private final double[] leafBoundaries;
     private final int maxDepth;
     private final int leafSize;
     private volatile boolean preparedForInference = false;
@@ -108,7 +108,7 @@ public class TreeInferenceModel implements InferenceModel {
         this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
         this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
         this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
-        this.highOrderCategory = maxLeafValue();
+        this.leafBoundaries = getLeafBoundaries();
         int leafSize = 1;
         for (Node node : this.nodes) {
             if (node instanceof LeafNode leafNode) {
@@ -218,7 +218,7 @@ public class TreeInferenceModel implements InferenceModel {
         }
         // If we are classification, we should assume that the inference return value is whole.
         assert inferenceValue[0] == Math.rint(inferenceValue[0]);
-        double maxCategory = this.highOrderCategory;
+        double maxCategory = getHighOrderCategory();
         // If we are classification, we should assume that the largest leaf value is whole.
         assert maxCategory == Math.rint(maxCategory);
         double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
@@ -366,21 +366,20 @@ public class TreeInferenceModel implements InferenceModel {
         return size;
     }
 
-    private double maxLeafValue() {
-        if (targetType != TargetType.CLASSIFICATION) {
-            return Double.NaN;
-        }
-        double max = 0.0;
+    private double[] getLeafBoundaries() {
+        double[] bounds = new double[] { Double.MAX_VALUE, Double.MIN_VALUE };
+
         for (Node node : this.nodes) {
             if (node instanceof LeafNode leafNode) {
                 if (leafNode.leafValue.length > 1) {
-                    return leafNode.leafValue.length;
+                    return new double[] { 0, leafNode.leafValue.length };
                 } else {
-                    max = Math.max(leafNode.leafValue[0], max);
+                    bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]);
+                    bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]);
                 }
             }
         }
-        return max;
+        return bounds;
     }
 
     public Node[] getNodes() {
@@ -389,24 +388,35 @@ public class TreeInferenceModel implements InferenceModel {
 
     @Override
     public String toString() {
-        return "TreeInferenceModel{"
-            + "nodes="
-            + Arrays.toString(nodes)
-            + ", featureNames="
-            + Arrays.toString(featureNames)
-            + ", targetType="
-            + targetType
-            + ", classificationLabels="
-            + classificationLabels
-            + ", highOrderCategory="
-            + highOrderCategory
-            + ", maxDepth="
-            + maxDepth
-            + ", leafSize="
-            + leafSize
-            + ", preparedForInference="
-            + preparedForInference
-            + '}';
+        StringBuilder builder = new StringBuilder("TreeInferenceModel{");
+
+        builder.append("nodes=")
+            .append(Arrays.toString(nodes))
+            .append(", featureNames=")
+            .append(Arrays.toString(featureNames))
+            .append(", targetType=")
+            .append(targetType);
+
+        if (targetType == TargetType.CLASSIFICATION) {
+            builder.append(", classificationLabels=")
+                .append(classificationLabels)
+                .append(", highOrderCategory=")
+                .append(getHighOrderCategory());
+        } else if (targetType == TargetType.REGRESSION) {
+            builder.append(", minPredictedValue=")
+                .append(getMinPredictedValue())
+                .append(", maxPredictedValue=")
+                .append(getMaxPredictedValue());
+        }
+
+        builder.append(", maxDepth=")
+            .append(maxDepth)
+            .append(", leafSize=")
+            .append(leafSize)
+            .append(", preparedForInference=")
+            .append(preparedForInference);
+
+        return builder.append('}').toString();
     }
 
     private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
@@ -420,6 +430,20 @@ public class TreeInferenceModel implements InferenceModel {
         return Math.max(depthLeft, depthRight) + 1;
     }
 
+    @Override
+    public double getMinPredictedValue() {
+        return leafBoundaries[0];
+    }
+
+    @Override
+    public double getMaxPredictedValue() {
+        return leafBoundaries[1];
+    }
+
+    private double getHighOrderCategory() {
+        return getMaxPredictedValue();
+    }
+
     static class NodeBuilder {
 
         private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(

+ 116 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModelTests.java

@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
+import static org.hamcrest.Matchers.equalTo;
+
+public class BoundedWindowInferenceModelTests extends ESTestCase {
+
+    private static final List<String> featureNames = Arrays.asList("foo", "bar");
+
+    public void testBoundsSetting() throws IOException {
+        BoundedWindowInferenceModel testModel = getModel(-2.0, 5.2, 10.5);
+        assertThat(testModel.getMinPredictedValue(), equalTo(-2.0));
+        assertThat(testModel.getMaxPredictedValue(), equalTo(10.5));
+    }
+
+    public void testInferenceScoresWithoutAdjustment() throws IOException {
+        BoundedWindowInferenceModel testModel = getModel(1.0, 5.2, 10.5);
+
+        List<Double> featureVector = Arrays.asList(0.4, 0.0);
+        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
+        Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
+            featureMap,
+            RegressionConfig.EMPTY_PARAMS,
+            Collections.emptyMap()
+        )).value();
+        assertThat(lowResultValue, equalTo(1.0));
+
+        featureVector = Arrays.asList(12.0, 0.0);
+        featureMap = zipObjMap(featureNames, featureVector);
+        Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
+            featureMap,
+            RegressionConfig.EMPTY_PARAMS,
+            Collections.emptyMap()
+        )).value();
+        assertThat(highResultValue, equalTo(10.5));
+
+        double[] featureArray = new double[2];
+        featureArray[0] = 12.0;
+        featureArray[1] = 0.0;
+        Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
+            .value();
+        assertThat(highResultValueFromFeatures, equalTo(10.5));
+    }
+
+    public void testInferenceScoresWithAdjustment() throws IOException {
+        BoundedWindowInferenceModel testModel = getModel(-5.0, 1.2, 6.5);
+
+        List<Double> featureVector = Arrays.asList(-10.0, 0.0);
+        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
+        Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
+            featureMap,
+            RegressionConfig.EMPTY_PARAMS,
+            Collections.emptyMap()
+        )).value();
+        assertThat(lowResultValue, equalTo(0.0));
+
+        featureVector = Arrays.asList(12.0, 0.0);
+        featureMap = zipObjMap(featureNames, featureVector);
+        Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
+            featureMap,
+            RegressionConfig.EMPTY_PARAMS,
+            Collections.emptyMap()
+        )).value();
+        assertThat(highResultValue, equalTo(11.5));
+
+        double[] featureArray = new double[2];
+        featureArray[0] = 12.0;
+        featureArray[1] = 0.0;
+        Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
+            .value();
+        assertThat(highResultValueFromFeatures, equalTo(11.5));
+    }
+
+    private BoundedWindowInferenceModel getModel(double lowerBoundValue, double midValue, double upperBoundValue) throws IOException {
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), upperBoundValue);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), lowerBoundValue);
+        builder.addLeaf(leftChildNode.getRightChild(), midValue);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree treeObject = builder.setFeatureNames(featureNames).build();
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
+        tree.rewriteFeatureIndices(Collections.emptyMap());
+
+        return new BoundedWindowInferenceModel(tree);
+    }
+
+    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
+        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
+    }
+
+}

+ 35 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java

@@ -39,6 +39,7 @@ import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
@@ -537,6 +538,40 @@ public class EnsembleInferenceModelTests extends ESTestCase {
         assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
     }
 
+    public void testMinAndMaxBoundaries() throws IOException {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(0.3))
+            .addNode(TreeNode.builder(2).setThreshold(0.8).setSplitFeature(1).setLeftChild(3).setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(0.1))
+            .addNode(TreeNode.builder(4).setLeafValue(0.2))
+            .build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(1.5))
+            .addNode(TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        Ensemble ensembleObject = Ensemble.builder()
+            .setTargetType(TargetType.REGRESSION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setOutputAggregator(new WeightedSum(new double[] { 0.5, 0.5 }))
+            .build();
+
+        EnsembleInferenceModel ensemble = deserializeFromTrainedModel(
+            ensembleObject,
+            xContentRegistry(),
+            EnsembleInferenceModel::fromXContent
+        );
+        ensemble.rewriteFeatureIndices(Collections.emptyMap());
+
+        assertThat(ensemble.getMinPredictedValue(), equalTo(1.0));
+        assertThat(ensemble.getMaxPredictedValue(), equalTo(1.8));
+    }
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }

+ 40 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java

@@ -13,6 +13,8 @@ import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Strings;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -25,17 +27,26 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 
 import java.io.IOException;
 import java.text.ParseException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.ENSEMBLE_MODEL;
 import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.TREE_MODEL;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -176,6 +187,35 @@ public class InferenceDefinitionTests extends ESTestCase {
         }
     }
 
+    public void testWithLearningToRankConfiguration() throws IOException {
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), -2.0);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.2);
+        builder.addLeaf(leftChildNode.getRightChild(), 1.5);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree treeObject = builder.setFeatureNames(featureNames).build();
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
+        tree.rewriteFeatureIndices(Collections.emptyMap());
+
+        BoundedWindowInferenceModel testModel = new BoundedWindowInferenceModel(tree);
+
+        InferenceDefinition definition = new InferenceDefinition(testModel, null);
+        LearningToRankConfig config = new LearningToRankConfig(
+            randomBoolean() ? null : randomIntBetween(0, 10),
+            randomBoolean()
+                ? null
+                : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList()),
+            randomBoolean() ? null : randomMap(0, 10, () -> Tuple.tuple(randomIdentifier(), randomIdentifier()))
+        );
+
+        InferenceResults results = definition.infer(Map.of("foo", 1.0, "bar", 0.0), config);
+
+        assertThat(results.predictedValue(), equalTo(2.0));
+    }
+
     public static String getClassificationDefinition(boolean customPreprocessor) {
         return Strings.format("""
             {

+ 17 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java

@@ -284,6 +284,23 @@ public class TreeInferenceModelTests extends ESTestCase {
         assertThat(featureImportance[1][0], closeTo(2.5, eps));
     }
 
+    public void testMinAndMaxBoundaries() throws IOException {
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 0.3);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.2);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree treeObject = builder.setFeatureNames(featureNames).build();
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
+        tree.rewriteFeatureIndices(Collections.emptyMap());
+
+        assertThat(tree.getMinPredictedValue(), equalTo(0.1));
+        assertThat(tree.getMaxPredictedValue(), equalTo(0.3));
+    }
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }