|  | @@ -0,0 +1,402 @@
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 | 
	
		
			
				|  |  | + * or more contributor license agreements. Licensed under the Elastic License;
 | 
	
		
			
				|  |  | + * you may not use this file except in compliance with the Elastic License.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import org.elasticsearch.ElasticsearchException;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.io.stream.Writeable;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.settings.Settings;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.xcontent.XContentParser;
 | 
	
		
			
				|  |  | +import org.elasticsearch.search.SearchModule;
 | 
	
		
			
				|  |  | +import org.elasticsearch.test.AbstractSerializingTestCase;
 | 
	
		
			
				|  |  | +import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
 | 
	
		
			
				|  |  | +import org.junit.Before;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import java.io.IOException;
 | 
	
		
			
				|  |  | +import java.util.ArrayList;
 | 
	
		
			
				|  |  | +import java.util.Arrays;
 | 
	
		
			
				|  |  | +import java.util.Collections;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  | +import java.util.Map;
 | 
	
		
			
				|  |  | +import java.util.function.Predicate;
 | 
	
		
			
				|  |  | +import java.util.stream.Collectors;
 | 
	
		
			
				|  |  | +import java.util.stream.IntStream;
 | 
	
		
			
				|  |  | +import java.util.stream.Stream;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.closeTo;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.equalTo;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private boolean lenient;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Before
 | 
	
		
			
				|  |  | +    public void chooseStrictOrLenient() {
 | 
	
		
			
				|  |  | +        lenient = randomBoolean();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected boolean supportsUnknownFields() {
 | 
	
		
			
				|  |  | +        return lenient;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected Predicate<String> getRandomFieldsExcludeFilter() {
 | 
	
		
			
				|  |  | +        return field -> !field.isEmpty();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
 | 
	
		
			
				|  |  | +        return lenient ? Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public static Ensemble createRandom() {
 | 
	
		
			
				|  |  | +        int numberOfFeatures = randomIntBetween(1, 10);
 | 
	
		
			
				|  |  | +        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
 | 
	
		
			
				|  |  | +        int numberOfModels = randomIntBetween(1, 10);
 | 
	
		
			
				|  |  | +        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
 | 
	
		
			
				|  |  | +            .limit(numberOfModels)
 | 
	
		
			
				|  |  | +            .collect(Collectors.toList());
 | 
	
		
			
				|  |  | +        List<Double> weights = randomBoolean() ?
 | 
	
		
			
				|  |  | +            null :
 | 
	
		
			
				|  |  | +            Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
 | 
	
		
			
				|  |  | +        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
 | 
	
		
			
				|  |  | +        List<String> categoryLabels = null;
 | 
	
		
			
				|  |  | +        if (randomBoolean()) {
 | 
	
		
			
				|  |  | +            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return new Ensemble(featureNames,
 | 
	
		
			
				|  |  | +            models,
 | 
	
		
			
				|  |  | +            outputAggregator,
 | 
	
		
			
				|  |  | +            randomFrom(TargetType.values()),
 | 
	
		
			
				|  |  | +            categoryLabels);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected Ensemble createTestInstance() {
 | 
	
		
			
				|  |  | +        return createRandom();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected Writeable.Reader<Ensemble> instanceReader() {
 | 
	
		
			
				|  |  | +        return Ensemble::new;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected NamedXContentRegistry xContentRegistry() {
 | 
	
		
			
				|  |  | +        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
 | 
	
		
			
				|  |  | +        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
 | 
	
		
			
				|  |  | +        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
 | 
	
		
			
				|  |  | +        return new NamedXContentRegistry(namedXContent);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    protected NamedWriteableRegistry getNamedWriteableRegistry() {
 | 
	
		
			
				|  |  | +        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
 | 
	
		
			
				|  |  | +        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
 | 
	
		
			
				|  |  | +        return new NamedWriteableRegistry(entries);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
 | 
	
		
			
				|  |  | +        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder().setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ex = expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder().setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        int numberOfModels = 5;
 | 
	
		
			
				|  |  | +        List<Double> weights = new ArrayList<>(numberOfModels + 2);
 | 
	
		
			
				|  |  | +        for (int i = 0; i < numberOfModels + 2; i++) {
 | 
	
		
			
				|  |  | +            weights.add(randomDouble());
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        List<TrainedModel> models = new ArrayList<>(numberOfModels);
 | 
	
		
			
				|  |  | +        for (int i = 0; i < numberOfModels; i++) {
 | 
	
		
			
				|  |  | +            models.add(TreeTests.buildRandomTree(featureNames, 6));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder()
 | 
	
		
			
				|  |  | +                .setTrainedModels(models)
 | 
	
		
			
				|  |  | +                .setOutputAggregator(outputAggregator)
 | 
	
		
			
				|  |  | +                .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]"));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testEnsembleWithInvalidModel() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder()
 | 
	
		
			
				|  |  | +                .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .setTrainedModels(Arrays.asList(
 | 
	
		
			
				|  |  | +                // Tree with loop
 | 
	
		
			
				|  |  | +                Tree.builder()
 | 
	
		
			
				|  |  | +                    .setNodes(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                    .setLeftChild(1)
 | 
	
		
			
				|  |  | +                    .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                    .setThreshold(randomDouble()),
 | 
	
		
			
				|  |  | +                TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                    .setLeftChild(0)
 | 
	
		
			
				|  |  | +                    .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                    .setThreshold(randomDouble()))
 | 
	
		
			
				|  |  | +                    .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                    .build()))
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testEnsembleWithTargetTypeAndLabelsMismatch() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
 | 
	
		
			
				|  |  | +        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder()
 | 
	
		
			
				|  |  | +                .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .setTrainedModels(Arrays.asList(
 | 
	
		
			
				|  |  | +                    Tree.builder()
 | 
	
		
			
				|  |  | +                        .setNodes(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                                .setThreshold(randomDouble()))
 | 
	
		
			
				|  |  | +                        .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                        .build()))
 | 
	
		
			
				|  |  | +                .setClassificationLabels(Arrays.asList("label1", "label2"))
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        assertThat(ex.getMessage(), equalTo(msg));
 | 
	
		
			
				|  |  | +        ex = expectThrows(ElasticsearchException.class, () -> {
 | 
	
		
			
				|  |  | +            Ensemble.builder()
 | 
	
		
			
				|  |  | +                .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                .setTrainedModels(Arrays.asList(
 | 
	
		
			
				|  |  | +                    Tree.builder()
 | 
	
		
			
				|  |  | +                        .setNodes(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                            .setLeftChild(1)
 | 
	
		
			
				|  |  | +                            .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                            .setThreshold(randomDouble()))
 | 
	
		
			
				|  |  | +                        .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +                        .build()))
 | 
	
		
			
				|  |  | +                .setTargetType(TargetType.CLASSIFICATION)
 | 
	
		
			
				|  |  | +                .build()
 | 
	
		
			
				|  |  | +                .validate();
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +        assertThat(ex.getMessage(), equalTo(msg));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testClassificationProbability() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        Tree tree1 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2)
 | 
	
		
			
				|  |  | +                .setThreshold(0.8)
 | 
	
		
			
				|  |  | +                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                .setLeftChild(3)
 | 
	
		
			
				|  |  | +                .setRightChild(4))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(3).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
 | 
	
		
			
				|  |  | +        Tree tree2 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +        Tree tree3 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                .setThreshold(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +        Ensemble ensemble = Ensemble.builder()
 | 
	
		
			
				|  |  | +            .setTargetType(TargetType.CLASSIFICATION)
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
 | 
	
		
			
				|  |  | +            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        List<Double> featureVector = Arrays.asList(0.4, 0.0);
 | 
	
		
			
				|  |  | +        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
 | 
	
		
			
				|  |  | +        double eps = 0.000001;
 | 
	
		
			
				|  |  | +        List<Double> probabilities = ensemble.classificationProbability(featureMap);
 | 
	
		
			
				|  |  | +        for(int i = 0; i < expected.size(); i++) {
 | 
	
		
			
				|  |  | +            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(2.0, 0.7);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        expected = Arrays.asList(0.3100255188, 0.689974481);
 | 
	
		
			
				|  |  | +        probabilities = ensemble.classificationProbability(featureMap);
 | 
	
		
			
				|  |  | +        for(int i = 0; i < expected.size(); i++) {
 | 
	
		
			
				|  |  | +            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(0.0, 1.0);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        expected = Arrays.asList(0.231475216, 0.768524783);
 | 
	
		
			
				|  |  | +        probabilities = ensemble.classificationProbability(featureMap);
 | 
	
		
			
				|  |  | +        for(int i = 0; i < expected.size(); i++) {
 | 
	
		
			
				|  |  | +            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testClassificationInference() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        Tree tree1 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2)
 | 
	
		
			
				|  |  | +                .setThreshold(0.8)
 | 
	
		
			
				|  |  | +                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                .setLeftChild(3)
 | 
	
		
			
				|  |  | +                .setRightChild(4))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(3).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
 | 
	
		
			
				|  |  | +        Tree tree2 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +        Tree tree3 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                .setThreshold(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(1.0))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2).setLeafValue(0.0))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +        Ensemble ensemble = Ensemble.builder()
 | 
	
		
			
				|  |  | +            .setTargetType(TargetType.CLASSIFICATION)
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
 | 
	
		
			
				|  |  | +            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        List<Double> featureVector = Arrays.asList(0.4, 0.0);
 | 
	
		
			
				|  |  | +        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(2.0, 0.7);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(0.0, 1.0);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testRegressionInference() {
 | 
	
		
			
				|  |  | +        List<String> featureNames = Arrays.asList("foo", "bar");
 | 
	
		
			
				|  |  | +        Tree tree1 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(0.3))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2)
 | 
	
		
			
				|  |  | +                .setThreshold(0.8)
 | 
	
		
			
				|  |  | +                .setSplitFeature(1)
 | 
	
		
			
				|  |  | +                .setLeftChild(3)
 | 
	
		
			
				|  |  | +                .setRightChild(4))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(3).setLeafValue(0.1))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
 | 
	
		
			
				|  |  | +        Tree tree2 = Tree.builder()
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setRoot(TreeNode.builder(0)
 | 
	
		
			
				|  |  | +                .setLeftChild(1)
 | 
	
		
			
				|  |  | +                .setRightChild(2)
 | 
	
		
			
				|  |  | +                .setSplitFeature(0)
 | 
	
		
			
				|  |  | +                .setThreshold(0.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(1).setLeafValue(1.5))
 | 
	
		
			
				|  |  | +            .addNode(TreeNode.builder(2).setLeafValue(0.9))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +        Ensemble ensemble = Ensemble.builder()
 | 
	
		
			
				|  |  | +            .setTargetType(TargetType.REGRESSION)
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setTrainedModels(Arrays.asList(tree1, tree2))
 | 
	
		
			
				|  |  | +            .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        List<Double> featureVector = Arrays.asList(0.4, 0.0);
 | 
	
		
			
				|  |  | +        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(0.9, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(2.0, 0.7);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(0.5, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
 | 
	
		
			
				|  |  | +        ensemble = Ensemble.builder()
 | 
	
		
			
				|  |  | +            .setTargetType(TargetType.REGRESSION)
 | 
	
		
			
				|  |  | +            .setFeatureNames(featureNames)
 | 
	
		
			
				|  |  | +            .setTrainedModels(Arrays.asList(tree1, tree2))
 | 
	
		
			
				|  |  | +            .build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(0.4, 0.0);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        featureVector = Arrays.asList(2.0, 0.7);
 | 
	
		
			
				|  |  | +        featureMap = zipObjMap(featureNames, featureVector);
 | 
	
		
			
				|  |  | +        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
 | 
	
		
			
				|  |  | +        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 |