Browse Source

[ML][HLRC] Add data frame analytics regression analysis (#46024)

Dimitris Athanasiou 6 years ago
parent
commit
eab64250eb

+ 5 - 1
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java

@@ -32,6 +32,10 @@ public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentPr
             new NamedXContentRegistry.Entry(
                 DataFrameAnalysis.class,
                 OutlierDetection.NAME,
-                (p, c) -> OutlierDetection.fromXContent(p)));
+                (p, c) -> OutlierDetection.fromXContent(p)),
+            new NamedXContentRegistry.Entry(
+                DataFrameAnalysis.class,
+                Regression.NAME,
+                (p, c) -> Regression.fromXContent(p)));
     }
 }

+ 242 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

@@ -0,0 +1,242 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class Regression implements DataFrameAnalysis {
+
+    public static Regression fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public static Builder builder(String dependentVariable) {
+        return new Builder(dependentVariable);
+    }
+
+    public static final ParseField NAME = new ParseField("regression");
+
+    static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
+    static final ParseField LAMBDA = new ParseField("lambda");
+    static final ParseField GAMMA = new ParseField("gamma");
+    static final ParseField ETA = new ParseField("eta");
+    static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
+    static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
+    static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
+
+    private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
+        a -> new Regression(
+            (String) a[0],
+            (Double) a[1],
+            (Double) a[2],
+            (Double) a[3],
+            (Integer) a[4],
+            (Double) a[5],
+            (String) a[6],
+            (Double) a[7]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
+    }
+
+    private final String dependentVariable;
+    private final Double lambda;
+    private final Double gamma;
+    private final Double eta;
+    private final Integer maximumNumberTrees;
+    private final Double featureBagFraction;
+    private final String predictionFieldName;
+    private final Double trainingPercent;
+
+    private Regression(String dependentVariable,  @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
+                       @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
+                       @Nullable Double trainingPercent) {
+        this.dependentVariable = Objects.requireNonNull(dependentVariable);
+        this.lambda = lambda;
+        this.gamma = gamma;
+        this.eta = eta;
+        this.maximumNumberTrees = maximumNumberTrees;
+        this.featureBagFraction = featureBagFraction;
+        this.predictionFieldName = predictionFieldName;
+        this.trainingPercent = trainingPercent;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    public String getDependentVariable() {
+        return dependentVariable;
+    }
+
+    public Double getLambda() {
+        return lambda;
+    }
+
+    public Double getGamma() {
+        return gamma;
+    }
+
+    public Double getEta() {
+        return eta;
+    }
+
+    public Integer getMaximumNumberTrees() {
+        return maximumNumberTrees;
+    }
+
+    public Double getFeatureBagFraction() {
+        return featureBagFraction;
+    }
+
+    public String getPredictionFieldName() {
+        return predictionFieldName;
+    }
+
+    public Double getTrainingPercent() {
+        return trainingPercent;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
+        if (lambda != null) {
+            builder.field(LAMBDA.getPreferredName(), lambda);
+        }
+        if (gamma != null) {
+            builder.field(GAMMA.getPreferredName(), gamma);
+        }
+        if (eta != null) {
+            builder.field(ETA.getPreferredName(), eta);
+        }
+        if (maximumNumberTrees != null) {
+            builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
+        }
+        if (featureBagFraction != null) {
+            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
+        }
+        if (predictionFieldName != null) {
+            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
+        }
+        if (trainingPercent != null) {
+            builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
+            trainingPercent);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(dependentVariable, that.dependentVariable)
+            && Objects.equals(lambda, that.lambda)
+            && Objects.equals(gamma, that.gamma)
+            && Objects.equals(eta, that.eta)
+            && Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
+            && Objects.equals(featureBagFraction, that.featureBagFraction)
+            && Objects.equals(predictionFieldName, that.predictionFieldName)
+            && Objects.equals(trainingPercent, that.trainingPercent);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    public static class Builder {
+        private String dependentVariable;
+        private Double lambda;
+        private Double gamma;
+        private Double eta;
+        private Integer maximumNumberTrees;
+        private Double featureBagFraction;
+        private String predictionFieldName;
+        private Double trainingPercent;
+
+        private Builder(String dependentVariable) {
+            this.dependentVariable = Objects.requireNonNull(dependentVariable);
+        }
+
+        public Builder setLambda(Double lambda) {
+            this.lambda = lambda;
+            return this;
+        }
+
+        public Builder setGamma(Double gamma) {
+            this.gamma = gamma;
+            return this;
+        }
+
+        public Builder setEta(Double eta) {
+            this.eta = eta;
+            return this;
+        }
+
+        public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
+            this.maximumNumberTrees = maximumNumberTrees;
+            return this;
+        }
+
+        public Builder setFeatureBagFraction(Double featureBagFraction) {
+            this.featureBagFraction = featureBagFraction;
+            return this;
+        }
+
+        public Builder setPredictionFieldName(String predictionFieldName) {
+            this.predictionFieldName = predictionFieldName;
+            return this;
+        }
+
+        public Builder setTrainingPercent(Double trainingPercent) {
+            this.trainingPercent = trainingPercent;
+            return this;
+        }
+
+        public Regression build() {
+            return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
+                trainingPercent);
+        }
+    }
+}

+ 37 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1215,9 +1215,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(remainingIds, not(hasItem(deletedEvent)));
     }
 
-    public void testPutDataFrameAnalyticsConfig() throws Exception {
+    public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
-        String configId = "put-test-config";
+        String configId = "test-put-df-analytics-outlier-detection";
         DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
             .setId(configId)
             .setSource(DataFrameAnalyticsSource.builder()
@@ -1247,6 +1247,41 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(createdConfig.getDescription(), equalTo("some description"));
     }
 
+    public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+        String configId = "test-put-df-analytics-regression";
+        DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
+            .setId(configId)
+            .setSource(DataFrameAnalyticsSource.builder()
+                .setIndex("put-test-source-index")
+                .build())
+            .setDest(DataFrameAnalyticsDest.builder()
+                .setIndex("put-test-dest-index")
+                .build())
+            .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression
+                .builder("my_dependent_variable")
+                .setTrainingPercent(80.0)
+                .build())
+            .setDescription("this is a regression")
+            .build();
+
+        createIndex("put-test-source-index", defaultMappingForTest());
+
+        PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
+            new PutDataFrameAnalyticsRequest(config),
+            machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
+        DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
+        assertThat(createdConfig.getId(), equalTo(config.getId()));
+        assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
+        assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder())));  // default value
+        assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
+        assertThat(createdConfig.getDest().getResultsField(), equalTo("ml"));  // default value
+        assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
+        assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
+        assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", "")));  // default value
+        assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
+    }
+
     public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
         String configId = "get-test-config";

+ 3 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -20,7 +20,6 @@
 package org.elasticsearch.client;
 
 import com.fasterxml.jackson.core.JsonParseException;
-
 import org.apache.http.HttpEntity;
 import org.apache.http.HttpHost;
 import org.apache.http.HttpResponse;
@@ -677,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(36, namedXContents.size());
+        assertEquals(37, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -711,8 +710,9 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(ShrinkAction.NAME));
         assertTrue(names.contains(FreezeAction.NAME));
         assertTrue(names.contains(SetPriorityAction.NAME));
-        assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class));
+        assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
         assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
+        assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
         assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
         assertTrue(names.contains(TimeSyncConfig.NAME));
         assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));

+ 17 - 4
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
+import org.elasticsearch.client.ml.dataframe.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -2923,16 +2924,28 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .build();
             // end::put-data-frame-analytics-dest-config
 
-            // tag::put-data-frame-analytics-analysis-default
+            // tag::put-data-frame-analytics-outlier-detection-default
             DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1>
-            // end::put-data-frame-analytics-analysis-default
+            // end::put-data-frame-analytics-outlier-detection-default
 
-            // tag::put-data-frame-analytics-analysis-customized
+            // tag::put-data-frame-analytics-outlier-detection-customized
             DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1>
                 .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2>
                 .setNNeighbors(5) // <3>
                 .build();
-            // end::put-data-frame-analytics-analysis-customized
+            // end::put-data-frame-analytics-outlier-detection-customized
+
+            // tag::put-data-frame-analytics-regression
+            DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
+                .setLambda(1.0) // <2>
+                .setGamma(5.5) // <3>
+                .setEta(5.5) // <4>
+                .setMaximumNumberTrees(50) // <5>
+                .setFeatureBagFraction(0.4) // <6>
+                .setPredictionFieldName("my_prediction_field_name") // <7>
+                .setTrainingPercent(50.0) // <8>
+                .build();
+            // end::put-data-frame-analytics-regression
 
             // tag::put-data-frame-analytics-analyzed-fields
             FetchSourceContext analyzedFields =

+ 54 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java

@@ -0,0 +1,54 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class RegressionTests extends AbstractXContentTestCase<Regression> {
+
+    public static Regression randomRegression() {
+        return Regression.builder(randomAlphaOfLength(10))
+            .setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
+            .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
+            .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
+            .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
+            .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
+            .build();
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return randomRegression();
+    }
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 24 - 4
docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

@@ -75,25 +75,45 @@ include-tagged::{doc-tests-file}[{api}-dest-config]
 ==== Analysis
 
 The analysis to be performed.
-Currently, only one analysis is supported: +OutlierDetection+.
+Currently, the supported analyses include : +OutlierDetection+, +Regression+.
+
+===== Outlier Detection
 
 +OutlierDetection+ analysis can be created in one of two ways:
 
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------
-include-tagged::{doc-tests-file}[{api}-analysis-default]
+include-tagged::{doc-tests-file}[{api}-outlier-detection-default]
 --------------------------------------------------
 <1> Constructing a new OutlierDetection object with default strategy to determine outliers
 
 or
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------
-include-tagged::{doc-tests-file}[{api}-analysis-customized]
+include-tagged::{doc-tests-file}[{api}-outlier-detection-customized]
 --------------------------------------------------
 <1> Constructing a new OutlierDetection object
 <2> The method used to perform the analysis
 <3> Number of neighbors taken into account during analysis
 
+===== Regression
+
++Regression+ analysis requires to set which is the +dependent_variable+ and
+has a number of other optional parameters:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-regression]
+--------------------------------------------------
+<1> Constructing a new Regression builder object with the required dependent variable
+<2> The lambda regularization parameter. A non-negative double.
+<3> The gamma regularization parameter. A non-negative double.
+<4> The applied shrinkage. A double in [0.001, 1].
+<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
+<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
+<7> The name of the prediction field in the results object.
+<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
+
 ==== Analyzed fields
 
 FetchContext object containing fields to be included in / excluded from the analysis
@@ -113,4 +133,4 @@ The returned +{response}+ contains the newly created {dataframe-analytics-config
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------
 include-tagged::{doc-tests-file}[{api}-response]
---------------------------------------------------
+--------------------------------------------------