|  | @@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.PhaseProgress;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.QueryConfig;
 | 
	
		
			
				|  |  | +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 | 
	
		
			
				|  |  | +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
 | 
	
	
		
			
				|  | @@ -1573,19 +1575,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
 | 
	
		
			
				|  |  |          String indexName = "evaluate-test-index";
 | 
	
		
			
				|  |  | -        createIndex(indexName, mappingForClassification());
 | 
	
		
			
				|  |  | +        createIndex(indexName, mappingForSoftClassification());
 | 
	
		
			
				|  |  |          BulkRequest bulk = new BulkRequest()
 | 
	
		
			
				|  |  |              .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", false, 0.1))  // #0
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", false, 0.2))  // #1
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", false, 0.3))  // #2
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", false, 0.4))  // #3
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", false, 0.7))  // #4
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 0.2))  // #5
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.3))  // #6
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.4))  // #7
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.8))  // #8
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.9));  // #9
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", false, 0.1))  // #0
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", false, 0.2))  // #1
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", false, 0.3))  // #2
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", false, 0.4))  // #3
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", false, 0.7))  // #4
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 0.2))  // #5
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.3))  // #6
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.4))  // #7
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.8))  // #8
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.9));  // #9
 | 
	
		
			
				|  |  |          highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
 | 
	
	
		
			
				|  | @@ -1647,19 +1649,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
 | 
	
		
			
				|  |  |          String indexName = "evaluate-with-query-test-index";
 | 
	
		
			
				|  |  | -        createIndex(indexName, mappingForClassification());
 | 
	
		
			
				|  |  | +        createIndex(indexName, mappingForSoftClassification());
 | 
	
		
			
				|  |  |          BulkRequest bulk = new BulkRequest()
 | 
	
		
			
				|  |  |              .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 1.0))  // #0
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 1.0))  // #1
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 1.0))  // #2
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 1.0))  // #3
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 0.0))  // #4
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "blue", true, 0.0))  // #5
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.0))  // #6
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.0))  // #7
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 0.0))  // #8
 | 
	
		
			
				|  |  | -            .add(docForClassification(indexName, "green", true, 1.0));  // #9
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #0
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #1
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #2
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #3
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 0.0))  // #4
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "blue", true, 0.0))  // #5
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #6
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #7
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #8
 | 
	
		
			
				|  |  | +            .add(docForSoftClassification(indexName, "green", true, 1.0));  // #9
 | 
	
		
			
				|  |  |          highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
 | 
	
	
		
			
				|  | @@ -1722,6 +1724,74 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |          assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public void testEvaluateDataFrame_Classification() throws IOException {
 | 
	
		
			
				|  |  | +        String indexName = "evaluate-classification-test-index";
 | 
	
		
			
				|  |  | +        createIndex(indexName, mappingForClassification());
 | 
	
		
			
				|  |  | +        BulkRequest regressionBulk = new BulkRequest()
 | 
	
		
			
				|  |  | +            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "cat", "cat"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "cat", "cat"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "cat", "cat"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "cat", "dog"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "cat", "fish"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "dog", "cat"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "dog", "dog"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "dog", "dog"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "dog", "dog"))
 | 
	
		
			
				|  |  | +            .add(docForClassification(indexName, "horse", "cat"));
 | 
	
		
			
				|  |  | +        highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        {  // No size provided for MulticlassConfusionMatrixMetric, default used instead
 | 
	
		
			
				|  |  | +            EvaluateDataFrameRequest evaluateDataFrameRequest =
 | 
	
		
			
				|  |  | +                new EvaluateDataFrameRequest(
 | 
	
		
			
				|  |  | +                    indexName,
 | 
	
		
			
				|  |  | +                    null,
 | 
	
		
			
				|  |  | +                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            EvaluateDataFrameResponse evaluateDataFrameResponse =
 | 
	
		
			
				|  |  | +                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
 | 
	
		
			
				|  |  | +            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
 | 
	
		
			
				|  |  | +            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            MulticlassConfusionMatrixMetric.Result mcmResult =
 | 
	
		
			
				|  |  | +                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
 | 
	
		
			
				|  |  | +            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
 | 
	
		
			
				|  |  | +            assertThat(
 | 
	
		
			
				|  |  | +                mcmResult.getConfusionMatrix(),
 | 
	
		
			
				|  |  | +                equalTo(
 | 
	
		
			
				|  |  | +                    Map.of(
 | 
	
		
			
				|  |  | +                        "cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
 | 
	
		
			
				|  |  | +                        "dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
 | 
	
		
			
				|  |  | +                        "horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
 | 
	
		
			
				|  |  | +            assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        {  // Explicit size provided for MulticlassConfusionMatrixMetric metric
 | 
	
		
			
				|  |  | +            EvaluateDataFrameRequest evaluateDataFrameRequest =
 | 
	
		
			
				|  |  | +                new EvaluateDataFrameRequest(
 | 
	
		
			
				|  |  | +                    indexName,
 | 
	
		
			
				|  |  | +                    null,
 | 
	
		
			
				|  |  | +                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            EvaluateDataFrameResponse evaluateDataFrameResponse =
 | 
	
		
			
				|  |  | +                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
 | 
	
		
			
				|  |  | +            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
 | 
	
		
			
				|  |  | +            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            MulticlassConfusionMatrixMetric.Result mcmResult =
 | 
	
		
			
				|  |  | +                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
 | 
	
		
			
				|  |  | +            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
 | 
	
		
			
				|  |  | +            assertThat(
 | 
	
		
			
				|  |  | +                mcmResult.getConfusionMatrix(),
 | 
	
		
			
				|  |  | +                equalTo(
 | 
	
		
			
				|  |  | +                    Map.of(
 | 
	
		
			
				|  |  | +                        "cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
 | 
	
		
			
				|  |  | +                        "dog", Map.of("cat", 1L, "dog", 3L))));
 | 
	
		
			
				|  |  | +            assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private static XContentBuilder defaultMappingForTest() throws IOException {
 | 
	
		
			
				|  |  |          return XContentFactory.jsonBuilder().startObject()
 | 
	
		
			
				|  |  |              .startObject("properties")
 | 
	
	
		
			
				|  | @@ -1739,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |      private static final String actualField = "label";
 | 
	
		
			
				|  |  |      private static final String probabilityField = "p";
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private static XContentBuilder mappingForClassification() throws IOException {
 | 
	
		
			
				|  |  | +    private static XContentBuilder mappingForSoftClassification() throws IOException {
 | 
	
		
			
				|  |  |          return XContentFactory.jsonBuilder().startObject()
 | 
	
		
			
				|  |  |              .startObject("properties")
 | 
	
		
			
				|  |  |                  .startObject(datasetField)
 | 
	
	
		
			
				|  | @@ -1755,26 +1825,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |          .endObject();
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
 | 
	
		
			
				|  |  | +    private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
 | 
	
		
			
				|  |  |          return new IndexRequest()
 | 
	
		
			
				|  |  |              .index(indexName)
 | 
	
		
			
				|  |  |              .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    private static final String actualClassField = "actual_class";
 | 
	
		
			
				|  |  | +    private static final String predictedClassField = "predicted_class";
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static XContentBuilder mappingForClassification() throws IOException {
 | 
	
		
			
				|  |  | +        return XContentFactory.jsonBuilder().startObject()
 | 
	
		
			
				|  |  | +            .startObject("properties")
 | 
	
		
			
				|  |  | +                .startObject(actualClassField)
 | 
	
		
			
				|  |  | +                    .field("type", "keyword")
 | 
	
		
			
				|  |  | +                .endObject()
 | 
	
		
			
				|  |  | +                .startObject(predictedClassField)
 | 
	
		
			
				|  |  | +                    .field("type", "keyword")
 | 
	
		
			
				|  |  | +                .endObject()
 | 
	
		
			
				|  |  | +            .endObject()
 | 
	
		
			
				|  |  | +        .endObject();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
 | 
	
		
			
				|  |  | +        return new IndexRequest()
 | 
	
		
			
				|  |  | +            .index(indexName)
 | 
	
		
			
				|  |  | +            .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private static final String actualRegression = "regression_actual";
 | 
	
		
			
				|  |  |      private static final String probabilityRegression = "regression_prob";
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private static XContentBuilder mappingForRegression() throws IOException {
 | 
	
		
			
				|  |  |          return XContentFactory.jsonBuilder().startObject()
 | 
	
		
			
				|  |  |              .startObject("properties")
 | 
	
		
			
				|  |  | -            .startObject(actualRegression)
 | 
	
		
			
				|  |  | -            .field("type", "double")
 | 
	
		
			
				|  |  | -            .endObject()
 | 
	
		
			
				|  |  | -            .startObject(probabilityRegression)
 | 
	
		
			
				|  |  | -            .field("type", "double")
 | 
	
		
			
				|  |  | -            .endObject()
 | 
	
		
			
				|  |  | +                .startObject(actualRegression)
 | 
	
		
			
				|  |  | +                    .field("type", "double")
 | 
	
		
			
				|  |  | +                .endObject()
 | 
	
		
			
				|  |  | +                .startObject(probabilityRegression)
 | 
	
		
			
				|  |  | +                    .field("type", "double")
 | 
	
		
			
				|  |  | +                .endObject()
 | 
	
		
			
				|  |  |              .endObject()
 | 
	
		
			
				|  |  | -            .endObject();
 | 
	
		
			
				|  |  | +        .endObject();
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private static IndexRequest docForRegression(String indexName, double act, double p) {
 | 
	
	
		
			
				|  | @@ -1789,11 +1881,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testEstimateMemoryUsage() throws IOException {
 | 
	
		
			
				|  |  |          String indexName = "estimate-test-index";
 | 
	
		
			
				|  |  | -        createIndex(indexName, mappingForClassification());
 | 
	
		
			
				|  |  | +        createIndex(indexName, mappingForSoftClassification());
 | 
	
		
			
				|  |  |          BulkRequest bulk1 = new BulkRequest()
 | 
	
		
			
				|  |  |              .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
 | 
	
		
			
				|  |  |          for (int i = 0; i < 10; ++i) {
 | 
	
		
			
				|  |  | -            bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
 | 
	
		
			
				|  |  | +            bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -1819,7 +1911,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 | 
	
		
			
				|  |  |          BulkRequest bulk2 = new BulkRequest()
 | 
	
		
			
				|  |  |              .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
 | 
	
		
			
				|  |  |          for (int i = 10; i < 100; ++i) {
 | 
	
		
			
				|  |  | -            bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
 | 
	
		
			
				|  |  | +            bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
 | 
	
		
			
				|  |  |  
 |