Pārlūkot izejas kodu

Allow the user to specify 'query' in Evaluate Data Frame request (#45775)

Przemysław Witek 6 gadi atpakaļ
vecāks
revīzija
31f6e78acd
19 mainītis faili ar 414 papildinājumiem un 108 dzēšanām
  1. 33 11
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java
  2. 2 10
      client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
  3. 70 19
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  4. 24 23
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  5. 84 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java
  6. 6 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java
  7. 6 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java
  8. 9 8
      docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
  9. 7 1
      docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc
  10. 55 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java
  11. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java
  12. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java
  13. 4 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
  14. 6 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
  15. 29 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java
  16. 19 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java
  17. 19 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java
  18. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java
  19. 35 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 33 - 11
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java

@@ -21,7 +21,9 @@ package org.elasticsearch.client.ml;
 
 import org.elasticsearch.client.Validatable;
 import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.client.ml.dataframe.QueryConfig;
 import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -37,20 +39,25 @@ import java.util.Objects;
 import java.util.Optional;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 
 public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
 
     private static final ParseField INDEX = new ParseField("index");
+    private static final ParseField QUERY = new ParseField("query");
     private static final ParseField EVALUATION = new ParseField("evaluation");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<EvaluateDataFrameRequest, Void> PARSER =
         new ConstructingObjectParser<>(
-            "evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List<String>) args[0], (Evaluation) args[1]));
+            "evaluate_data_frame_request",
+            true,
+            args -> new EvaluateDataFrameRequest((List<String>) args[0], (QueryConfig) args[1], (Evaluation) args[2]));
 
     static {
         PARSER.declareStringArray(constructorArg(), INDEX);
+        PARSER.declareObject(optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY);
         PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
     }
 
@@ -67,14 +74,16 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
     }
 
     private List<String> indices;
+    private QueryConfig queryConfig;
     private Evaluation evaluation;
 
-    public EvaluateDataFrameRequest(String index, Evaluation evaluation) {
-        this(Arrays.asList(index), evaluation);
+    public EvaluateDataFrameRequest(String index, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
+        this(Arrays.asList(index), queryConfig, evaluation);
     }
 
-    public EvaluateDataFrameRequest(List<String> indices, Evaluation evaluation) {
+    public EvaluateDataFrameRequest(List<String> indices, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
         setIndices(indices);
+        setQueryConfig(queryConfig);
         setEvaluation(evaluation);
     }
 
@@ -87,6 +96,14 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
         this.indices = new ArrayList<>(indices);
     }
 
+    public QueryConfig getQueryConfig() {
+        return queryConfig;
+    }
+
+    public final void setQueryConfig(QueryConfig queryConfig) {
+        this.queryConfig = queryConfig;
+    }
+
     public Evaluation getEvaluation() {
         return evaluation;
     }
@@ -111,18 +128,22 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return builder
-            .startObject()
-                .array(INDEX.getPreferredName(), indices.toArray())
-                .startObject(EVALUATION.getPreferredName())
-                    .field(evaluation.getName(), evaluation)
-                .endObject()
+        builder.startObject();
+        builder.array(INDEX.getPreferredName(), indices.toArray());
+        if (queryConfig != null) {
+            builder.field(QUERY.getPreferredName(), queryConfig.getQuery());
+        }
+        builder
+            .startObject(EVALUATION.getPreferredName())
+                .field(evaluation.getName(), evaluation)
             .endObject();
+        builder.endObject();
+        return builder;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(indices, evaluation);
+        return Objects.hash(indices, queryConfig, evaluation);
     }
 
     @Override
@@ -131,6 +152,7 @@ public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
         if (o == null || getClass() != o.getClass()) return false;
         EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o;
         return Objects.equals(indices, that.indices)
+            && Objects.equals(queryConfig, that.queryConfig)
             && Objects.equals(evaluation, that.evaluation);
     }
 }

+ 2 - 10
client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

@@ -36,6 +36,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
 import org.elasticsearch.client.ml.DeleteJobRequest;
 import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
+import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
 import org.elasticsearch.client.ml.FindFileStructureRequest;
 import org.elasticsearch.client.ml.FindFileStructureRequestTests;
 import org.elasticsearch.client.ml.FlushJobRequest;
@@ -85,9 +86,6 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
-import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
-import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
 import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.Detector;
@@ -779,13 +777,7 @@ public class MLRequestConvertersTests extends ESTestCase {
     }
 
     public void testEvaluateDataFrame() throws IOException {
-        EvaluateDataFrameRequest evaluateRequest =
-            new EvaluateDataFrameRequest(
-                Arrays.asList(generateRandomStringArray(1, 10, false, false)),
-                new BinarySoftClassification(
-                    randomAlphaOfLengthBetween(1, 10),
-                    randomAlphaOfLengthBetween(1, 10),
-                    PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7)));
+        EvaluateDataFrameRequest evaluateRequest = EvaluateDataFrameRequestTests.createRandom();
         Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest);
         assertEquals(HttpPost.METHOD_NAME, request.getMethod());
         assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint());

+ 70 - 19
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -149,6 +149,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.junit.After;
@@ -1427,7 +1428,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
     public void testStopDataFrameAnalyticsConfig() throws Exception {
         String sourceIndex = "stop-test-source-index";
         String destIndex = "stop-test-dest-index";
-        createIndex(sourceIndex, mappingForClassification());
+        createIndex(sourceIndex, defaultMappingForTest());
         highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
 
@@ -1525,27 +1526,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(exception.status().getStatus(), equalTo(404));
     }
 
-    public void testEvaluateDataFrame() throws IOException {
+    public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
         String indexName = "evaluate-test-index";
         createIndex(indexName, mappingForClassification());
         BulkRequest bulk = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .add(docForClassification(indexName, false, 0.1))  // #0
-            .add(docForClassification(indexName, false, 0.2))  // #1
-            .add(docForClassification(indexName, false, 0.3))  // #2
-            .add(docForClassification(indexName, false, 0.4))  // #3
-            .add(docForClassification(indexName, false, 0.7))  // #4
-            .add(docForClassification(indexName, true, 0.2))  // #5
-            .add(docForClassification(indexName, true, 0.3))  // #6
-            .add(docForClassification(indexName, true, 0.4))  // #7
-            .add(docForClassification(indexName, true, 0.8))  // #8
-            .add(docForClassification(indexName, true, 0.9));  // #9
+            .add(docForClassification(indexName, "blue", false, 0.1))  // #0
+            .add(docForClassification(indexName, "blue", false, 0.2))  // #1
+            .add(docForClassification(indexName, "blue", false, 0.3))  // #2
+            .add(docForClassification(indexName, "blue", false, 0.4))  // #3
+            .add(docForClassification(indexName, "blue", false, 0.7))  // #4
+            .add(docForClassification(indexName, "blue", true, 0.2))  // #5
+            .add(docForClassification(indexName, "green", true, 0.3))  // #6
+            .add(docForClassification(indexName, "green", true, 0.4))  // #7
+            .add(docForClassification(indexName, "green", true, 0.8))  // #8
+            .add(docForClassification(indexName, "green", true, 0.9));  // #9
         highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
         EvaluateDataFrameRequest evaluateDataFrameRequest =
             new EvaluateDataFrameRequest(
                 indexName,
+                null,
                 new BinarySoftClassification(
                     actualField,
                     probabilityField,
@@ -1596,7 +1598,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
         assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
         assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
+    }
+
+    public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
+        String indexName = "evaluate-with-query-test-index";
+        createIndex(indexName, mappingForClassification());
+        BulkRequest bulk = new BulkRequest()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .add(docForClassification(indexName, "blue", true, 1.0))  // #0
+            .add(docForClassification(indexName, "blue", true, 1.0))  // #1
+            .add(docForClassification(indexName, "blue", true, 1.0))  // #2
+            .add(docForClassification(indexName, "blue", true, 1.0))  // #3
+            .add(docForClassification(indexName, "blue", true, 0.0))  // #4
+            .add(docForClassification(indexName, "blue", true, 0.0))  // #5
+            .add(docForClassification(indexName, "green", true, 0.0))  // #6
+            .add(docForClassification(indexName, "green", true, 0.0))  // #7
+            .add(docForClassification(indexName, "green", true, 0.0))  // #8
+            .add(docForClassification(indexName, "green", true, 1.0));  // #9
+        highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+        EvaluateDataFrameRequest evaluateDataFrameRequest =
+            new EvaluateDataFrameRequest(
+                indexName,
+                // Request only "blue" subset to be evaluated
+                new QueryConfig(QueryBuilders.termQuery(datasetField, "blue")),
+                new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5)));
+
+        EvaluateDataFrameResponse evaluateDataFrameResponse =
+            execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+        ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME);
+        assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME));
+        ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5");
+        assertThat(confusionMatrix.getTruePositives(), equalTo(4L));  // docs #0, #1, #2 and #3
+        assertThat(confusionMatrix.getFalsePositives(), equalTo(0L));
+        assertThat(confusionMatrix.getTrueNegatives(), equalTo(0L));
+        assertThat(confusionMatrix.getFalseNegatives(), equalTo(2L));  // docs #4 and #5
+    }
+
+    public void testEvaluateDataFrame_Regression() throws IOException {
         String regressionIndex = "evaluate-regression-test-index";
         createIndex(regressionIndex, mappingForRegression());
         BulkRequest regressionBulk = new BulkRequest()
@@ -1613,10 +1656,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             .add(docForRegression(regressionIndex, 0.5, 0.9));  // #9
         highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
 
-        evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
-            new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+        EvaluateDataFrameRequest evaluateDataFrameRequest =
+            new EvaluateDataFrameRequest(
+                regressionIndex,
+                null,
+                new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
 
-        evaluateDataFrameResponse =
+        EvaluateDataFrameResponse evaluateDataFrameResponse =
             execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
         assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
@@ -1643,12 +1690,16 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         .endObject();
     }
 
+    private static final String datasetField = "dataset";
     private static final String actualField = "label";
     private static final String probabilityField = "p";
 
     private static XContentBuilder mappingForClassification() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
+                .startObject(datasetField)
+                    .field("type", "keyword")
+                .endObject()
                 .startObject(actualField)
                     .field("type", "keyword")
                 .endObject()
@@ -1659,10 +1710,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         .endObject();
     }
 
-    private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) {
+    private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
         return new IndexRequest()
             .index(indexName)
-            .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
+            .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
     }
 
     private static final String actualRegression = "regression_actual";
@@ -1697,7 +1748,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         BulkRequest bulk1 = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
         for (int i = 0; i < 10; ++i) {
-            bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
+            bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
         }
         highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
 
@@ -1723,7 +1774,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         BulkRequest bulk2 = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
         for (int i = 10; i < 100; ++i) {
-            bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
+            bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
         }
         highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
 

+ 24 - 23
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -178,7 +178,6 @@ import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.tasks.TaskId;
-import org.hamcrest.CoreMatchers;
 import org.junit.After;
 
 import java.io.IOException;
@@ -3179,16 +3178,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
         BulkRequest bulkRequest =
             new BulkRequest(indexName)
                 .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0
-                .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1
-                .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2
-                .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3
-                .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4
-                .add(new IndexRequest().source(XContentType.JSON, "label", true,  "p", 0.2)) // #5
-                .add(new IndexRequest().source(XContentType.JSON, "label", true,  "p", 0.3)) // #6
-                .add(new IndexRequest().source(XContentType.JSON, "label", true,  "p", 0.4)) // #7
-                .add(new IndexRequest().source(XContentType.JSON, "label", true,  "p", 0.8)) // #8
-                .add(new IndexRequest().source(XContentType.JSON, "label", true,  "p", 0.9)); // #9
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.1)) // #0
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.2)) // #1
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.3)) // #2
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.4)) // #3
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", false, "p", 0.7)) // #4
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true,  "p", 0.2)) // #5
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true,  "p", 0.3)) // #6
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true,  "p", 0.4)) // #7
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true,  "p", 0.8)) // #8
+                .add(new IndexRequest().source(XContentType.JSON, "dataset", "blue", "label", true,  "p", 0.9)); // #9
         RestHighLevelClient client = highLevelClient();
         client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
         client.bulk(bulkRequest, RequestOptions.DEFAULT);
@@ -3196,14 +3195,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             // tag::evaluate-data-frame-request
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
                 indexName, // <2>
-                new BinarySoftClassification( // <3>
-                    "label", // <4>
-                    "p", // <5>
-                    // Evaluation metrics // <6>
-                    PrecisionMetric.at(0.4, 0.5, 0.6), // <7>
-                    RecallMetric.at(0.5, 0.7), // <8>
-                    ConfusionMatrixMetric.at(0.5), // <9>
-                    AucRocMetric.withCurve())); // <10>
+                new QueryConfig(QueryBuilders.termQuery("dataset", "blue")),  // <3>
+                new BinarySoftClassification( // <4>
+                    "label", // <5>
+                    "p", // <6>
+                    // Evaluation metrics // <7>
+                    PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
+                    RecallMetric.at(0.5, 0.7), // <9>
+                    ConfusionMatrixMetric.at(0.5), // <10>
+                    AucRocMetric.withCurve())); // <11>
             // end::evaluate-data-frame-request
 
             // tag::evaluate-data-frame-execute
@@ -3224,14 +3224,15 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
                 containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
             assertThat(precision, closeTo(0.6, 1e-9));
-            assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L));  // docs #8 and #9
-            assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L));  // doc #4
-            assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L));  // docs #0, #1, #2 and #3
-            assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L));  // docs #5, #6 and #7
+            assertThat(confusionMatrix.getTruePositives(), equalTo(2L));  // docs #8 and #9
+            assertThat(confusionMatrix.getFalsePositives(), equalTo(1L));  // doc #4
+            assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L));  // docs #0, #1, #2 and #3
+            assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L));  // docs #5, #6 and #7
         }
         {
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(
                 indexName,
+                new QueryConfig(QueryBuilders.termQuery("dataset", "blue")),
                 new BinarySoftClassification(
                     "label",
                     "p",

+ 84 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameRequestTests.java

@@ -0,0 +1,84 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.QueryConfig;
+import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.RegressionTests;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Predicate;
+
+import static java.util.function.Predicate.not;
+
+public class EvaluateDataFrameRequestTests extends AbstractXContentTestCase<EvaluateDataFrameRequest> {
+
+    public static EvaluateDataFrameRequest createRandom() {
+        int indicesCount = randomIntBetween(1, 5);
+        List<String> indices = new ArrayList<>(indicesCount);
+        for (int i = 0; i < indicesCount; i++) {
+            indices.add(randomAlphaOfLength(10));
+        }
+        QueryConfig queryConfig = randomBoolean()
+            ? new QueryConfig(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)))
+            : null;
+        Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
+        return new EvaluateDataFrameRequest(indices, queryConfig, evaluation);
+    }
+
+    @Override
+    protected EvaluateDataFrameRequest createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected EvaluateDataFrameRequest doParseInstance(XContentParser parser) throws IOException {
+        return EvaluateDataFrameRequest.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in root only
+        return not(String::isEmpty);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+        return new NamedXContentRegistry(namedXContent);
+    }
+}

+ 6 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -36,8 +36,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
         return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
     }
 
-    @Override
-    protected Regression createTestInstance() {
+    public static Regression createRandom() {
         List<EvaluationMetric> metrics = new ArrayList<>();
         if (randomBoolean()) {
             metrics.add(new MeanSquaredErrorMetric());
@@ -50,6 +49,11 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
             new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
 
+    @Override
+    protected Regression createTestInstance() {
+        return createRandom();
+    }
+
     @Override
     protected Regression doParseInstance(XContentParser parser) throws IOException {
         return Regression.fromXContent(parser);

+ 6 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java

@@ -37,8 +37,7 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
         return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
     }
 
-    @Override
-    protected BinarySoftClassification createTestInstance() {
+    public static BinarySoftClassification createRandom() {
         List<EvaluationMetric> metrics = new ArrayList<>();
         if (randomBoolean()) {
             metrics.add(new AucRocMetric(randomBoolean()));
@@ -66,6 +65,11 @@ public class BinarySoftClassificationTests extends AbstractXContentTestCase<Bina
             new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
 
+    @Override
+    protected BinarySoftClassification createTestInstance() {
+        return createRandom();
+    }
+
     @Override
     protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException {
         return BinarySoftClassification.fromXContent(parser);

+ 9 - 8
docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

@@ -18,14 +18,15 @@ include-tagged::{doc-tests-file}[{api}-request]
 --------------------------------------------------
 <1> Constructing a new evaluation request
 <2> Reference to an existing index
-<3> Kind of evaluation to perform
-<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
-<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
-<6> The remaining parameters are the metrics to be calculated based on the two fields described above.
-<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
-<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
-<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
-<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
+<3> The query with which to select data from indices
+<4> Kind of evaluation to perform
+<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false
+<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive
+<7> The remaining parameters are the metrics to be calculated based on the two fields described above.
+<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6
+<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7
+<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5
+<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned
 
 include::../execution.asciidoc[]
 

+ 7 - 1
docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc

@@ -43,7 +43,13 @@ packages together commonly used metrics for various analyses.
 `index`::
   (Required, object) Defines the `index` in which the evaluation will be
   performed.
-  
+
+`query`::
+  (Optional, object) Query used to select data from the index.
+  The {es} query domain-specific language (DSL). This value corresponds to the query
+  object in an {es} search POST body. By default, this property has the following
+  value: `{"match_all": {}}`.
+
 `evaluation`::
   (Required, object) Defines the type of evaluation you want to perform. For example: 
   `binary_soft_classification`. See <<ml-evaluate-dfanalytics-resources>>.

+ 55 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java

@@ -5,12 +5,14 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestBuilder;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -20,14 +22,21 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
 public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.Response> {
 
@@ -41,14 +50,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
     public static class Request extends ActionRequest implements ToXContentObject {
 
         private static final ParseField INDEX = new ParseField("index");
+        private static final ParseField QUERY = new ParseField("query");
         private static final ParseField EVALUATION = new ParseField("evaluation");
 
-        private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(NAME,
-            a -> new Request((List<String>) a[0], (Evaluation) a[1]));
+        private static final ConstructingObjectParser<Request, Void> PARSER = new ConstructingObjectParser<>(
+            NAME,
+            a -> new Request((List<String>) a[0], (QueryProvider) a[1], (Evaluation) a[2]));
 
         static {
-            PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX);
-            PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
+            PARSER.declareStringArray(constructorArg(), INDEX);
+            PARSER.declareObject(
+                optionalConstructorArg(),
+                (p, c) -> QueryProvider.fromXContent(p, true, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT),
+                QUERY);
+            PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
         }
 
         private static Evaluation parseEvaluation(XContentParser parser) throws IOException {
@@ -64,19 +79,25 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
         }
 
         private String[] indices;
+        private QueryProvider queryProvider;
         private Evaluation evaluation;
 
-        private Request(List<String> indices, Evaluation evaluation) {
+        private Request(List<String> indices, @Nullable QueryProvider queryProvider, Evaluation evaluation) {
             setIndices(indices);
+            setQueryProvider(queryProvider);
             setEvaluation(evaluation);
         }
 
-        public Request() {
-        }
+        public Request() {}
 
         public Request(StreamInput in) throws IOException {
             super(in);
             indices = in.readStringArray();
+            if (in.getVersion().onOrAfter(Version.CURRENT)) {
+                if (in.readBoolean()) {
+                    queryProvider = QueryProvider.fromStream(in);
+                }
+            }
             evaluation = in.readNamedWriteable(Evaluation.class);
         }
 
@@ -92,6 +113,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             this.indices = indices.toArray(new String[indices.size()]);
         }
 
+        public QueryBuilder getParsedQuery() {
+            return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
+        }
+
+        public final void setQueryProvider(QueryProvider queryProvider) {
+            this.queryProvider = queryProvider;
+        }
+
         public Evaluation getEvaluation() {
             return evaluation;
         }
@@ -109,6 +138,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeStringArray(indices);
+            if (out.getVersion().onOrAfter(Version.CURRENT)) {
+                if (queryProvider != null) {
+                    out.writeBoolean(true);
+                    queryProvider.writeTo(out);
+                } else {
+                    out.writeBoolean(false);
+                }
+            }
             out.writeNamedWriteable(evaluation);
         }
 
@@ -116,16 +153,20 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
             builder.array(INDEX.getPreferredName(), indices);
-            builder.startObject(EVALUATION.getPreferredName());
-            builder.field(evaluation.getName(), evaluation);
-            builder.endObject();
+            if (queryProvider != null) {
+                builder.field(QUERY.getPreferredName(), queryProvider.getQuery());
+            }
+            builder
+                .startObject(EVALUATION.getPreferredName())
+                    .field(evaluation.getName(), evaluation)
+                .endObject();
             builder.endObject();
             return builder;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(Arrays.hashCode(indices), evaluation);
+            return Objects.hash(Arrays.hashCode(indices), queryProvider, evaluation);
         }
 
         @Override
@@ -133,7 +174,9 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             Request that = (Request) o;
-            return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation);
+            return Arrays.equals(indices, that.indices)
+                && Objects.equals(queryProvider, that.queryProvider)
+                && Objects.equals(evaluation, that.evaluation);
         }
     }
 
@@ -200,5 +243,4 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             return Strings.toString(this);
         }
     }
-
 }

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java

@@ -143,7 +143,8 @@ public class DataFrameAnalyticsSource implements Writeable, ToXContentObject {
         return deprecations;
     }
 
-    public Map<String, Object> getQuery() {
+    // Visible for testing
+    Map<String, Object> getQuery() {
         return queryProvider.getQuery();
     }
 }

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

@@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 
 import java.util.List;
@@ -25,8 +26,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
 
     /**
      * Builds the search required to collect data to compute the evaluation result
+     * @param queryBuilder User-provided query that must be respected when collecting data
      */
-    SearchSourceBuilder buildSearch();
+    SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
 
     /**
      * Computes the evaluation result

+ 4 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -106,10 +107,11 @@ public class Regression implements Evaluation {
     }
 
     @Override
-    public SearchSourceBuilder buildSearch() {
+    public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
         BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
             .filter(QueryBuilders.existsQuery(actualField))
-            .filter(QueryBuilders.existsQuery(predictedField));
+            .filter(QueryBuilders.existsQuery(predictedField))
+            .filter(queryBuilder);
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
         for (RegressionMetric metric : metrics) {
             List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);

+ 6 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

@@ -155,10 +155,12 @@ public class BinarySoftClassification implements Evaluation {
     }
 
     @Override
-    public SearchSourceBuilder buildSearch() {
-        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        searchSourceBuilder.size(0);
-        searchSourceBuilder.query(buildQuery());
+    public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
+        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
+            .filter(QueryBuilders.existsQuery(actualField))
+            .filter(QueryBuilders.existsQuery(predictedProbabilityField))
+            .filter(queryBuilder);
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
         for (SoftClassificationMetric metric : metrics) {
             List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
             aggs.forEach(searchSourceBuilder::aggregation);
@@ -166,13 +168,6 @@ public class BinarySoftClassification implements Evaluation {
         return searchSourceBuilder;
     }
 
-    private QueryBuilder buildQuery() {
-        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
-        boolQuery.filter(QueryBuilders.existsQuery(actualField));
-        boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField));
-        return boolQuery;
-    }
-
     @Override
     public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
         if (searchResponse.getHits().getTotalHits().value == 0) {

+ 29 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java

@@ -7,26 +7,41 @@ package org.elasticsearch.xpack.core.ml.action;
 
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionTests;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTestCase<Request> {
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
+        namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
     }
 
     @Override
     protected NamedXContentRegistry xContentRegistry() {
-        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
     }
 
     @Override
@@ -38,7 +53,18 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
             indices.add(randomAlphaOfLength(10));
         }
         request.setIndices(indices);
-        request.setEvaluation(BinarySoftClassificationTests.createRandom());
+        QueryProvider queryProvider = null;
+        if (randomBoolean()) {
+            try {
+                queryProvider = QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)));
+            } catch (IOException e) {
+                // Should never happen
+                throw new UncheckedIOException(e);
+            }
+        }
+        request.setQueryProvider(queryProvider);
+        Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
+        request.setEvaluation(evaluation);
         return request;
     }
 

+ 19 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -69,4 +72,20 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
             () -> new Regression("foo", "bar", Collections.emptyList()));
         assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
     }
+
+    public void testBuildSearch() {
+        Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));
+        QueryBuilder userProvidedQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
+        QueryBuilder expectedSearchQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.existsQuery("act"))
+                .filter(QueryBuilders.existsQuery("prob"))
+                .filter(QueryBuilders.boolQuery()
+                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));
+        assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
+    }
 }

+ 19 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java

@@ -10,11 +10,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -76,4 +79,20 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
             () -> new BinarySoftClassification("foo", "bar", Collections.emptyList()));
         assertThat(e.getMessage(), equalTo("[binary_soft_classification] must have one or more metrics"));
     }
+
+    public void testBuildSearch() {
+        BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
+        QueryBuilder userProvidedQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
+        QueryBuilder expectedSearchQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.existsQuery("act"))
+                .filter(QueryBuilders.existsQuery("prob"))
+                .filter(QueryBuilders.boolQuery()
+                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));
+        assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
+    }
 }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java

@@ -40,7 +40,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
                              ActionListener<EvaluateDataFrameAction.Response> listener) {
         Evaluation evaluation = request.getEvaluation();
         SearchRequest searchRequest = new SearchRequest(request.getIndices());
-        searchRequest.source(evaluation.buildSearch());
+        searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
 
         ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
             results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),

+ 35 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -5,6 +5,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "blue",
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.0,
@@ -19,6 +20,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "blue",
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.2,
@@ -33,6 +35,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "blue",
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.3,
@@ -47,6 +50,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "blue",
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.3,
@@ -61,6 +65,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "green",
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.4,
@@ -75,6 +80,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "green",
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.5,
@@ -89,6 +95,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "green",
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.9,
@@ -103,6 +110,7 @@ setup:
         index: utopia
         body:  >
           {
+            "dataset": "green",
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.95,
@@ -305,6 +313,33 @@ setup:
             tn: 3
             fn: 2
 
+---
+"Test binary_soft_classification with query":
+  - do:
+      ml.evaluate_data_frame:
+        body:  >
+          {
+            "index": "utopia",
+            "query": { "bool": { "filter": { "term": { "dataset": "blue" } } } },
+            "evaluation": {
+              "binary_soft_classification": {
+                "actual_field": "is_outlier",
+                "predicted_probability_field": "outlier_score",
+                "metrics": {
+                  "confusion_matrix": { "at": [0.5] }
+                }
+              }
+            }
+          }
+  - match:
+      binary_soft_classification:
+        confusion_matrix:
+          '0.5':
+            tp: 0
+            fp: 0
+            tn: 3
+            fn: 1
+
 ---
 "Test binary_soft_classification default metrics":
   - do: