فهرست منبع

[ML] Allow setting num_top_classes to a special value -1 (#63587)

Przemysław Witek 5 سال پیش
والد
کامیت
d9e7d88f08

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

@@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
             .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
             .setRandomizeSeed(randomBoolean() ? null : randomLong())
             .setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
-            .setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
+            .setNumTopClasses(randomBoolean() ? null : randomIntBetween(-1, 1000))
             .setFeatureProcessors(randomBoolean() ? null :
                 Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
                     OneHotEncodingTests.createRandom(),

+ 1 - 1
docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

@@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
 <9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
 <10> The seed to be used by the random generator that picks which rows are used in training.
 <11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
-<12> The number of top classes to be reported in the results. Defaults to 2.
+<12> The number of top classes (or -1, which denotes all classes) to be reported in the results. Defaults to 2.
 <13> Custom feature processors that will create new features for analysis from the included document
      fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
 

+ 3 - 2
docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

@@ -136,8 +136,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
 `num_top_classes`::::
 (Optional, integer)
 Defines the number of categories for which the predicted probabilities are
-reported. It must be non-negative. If it is greater than the total number of
-categories, the API reports all category probabilities. Defaults to 2.
+reported. It must be non-negative or -1 (which denotes all categories). If it is
+greater than the total number of categories, the API reports all category
+probabilities. Defaults to 2.
 
 `num_top_feature_importance_values`::::
 (Optional, integer)

+ 3 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -167,8 +167,9 @@ public class Classification implements DataFrameAnalysis {
                           @Nullable Double trainingPercent,
                           @Nullable Long randomizeSeed,
                           @Nullable List<PreProcessor> featureProcessors) {
-        if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
-            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
+        if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
         }
         if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
             throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());

+ 9 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
         Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
             null : randomFrom(Classification.ClassAssignmentObjective.values());
-        Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
+        Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
         Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
@@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
-    public void testConstructor_GivenNumTopClassesIsLessThanZero() {
+    public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
 
-        assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
+        assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
     }
 
     public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
             () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
 
-        assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
+        assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
     }
 
     public void testGetPredictionFieldName() {
@@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(7));
 
+        // Special value: num_top_classes == -1
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
+        assertThat(classification.getNumTopClasses(), equalTo(-1));
+
         // Boundary condition: num_top_classes == 0
         classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
         assertThat(classification.getNumTopClasses(), equalTo(0));

+ 1 - 1
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -92,7 +92,7 @@ yamlRestTest {
     'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
     'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
     'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
-    'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
+    'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than minus one',
     'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
     'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
     'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',

+ 3 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -367,7 +367,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         String predictedClassField = dependentVariable + "_prediction";
         indexData(sourceIndex, 300, 0, dependentVariable);
 
-        int numTopClasses = 2;
+        int numTopClasses = randomBoolean() ? 2 : -1;  // Occasionally it's worth testing the special value -1.
+        int expectedNumTopClasses = 2;
         DataFrameAnalyticsConfig config =
             buildAnalytics(
                 jobId,
@@ -391,7 +392,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
             assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
-            assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
+            assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues);
 
             // Let's just assert there's both training and non-training results
             //

+ 4 - 4
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -1868,10 +1868,10 @@ setup:
           }
 
 ---
-"Test put classification given num_top_classes is less than zero":
+"Test put classification given num_top_classes is less than minus one":
 
   - do:
-      catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
+      catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
       ml.put_data_frame_analytics:
         id: "classification-training-percent-is-less-than-one"
         body: >
@@ -1885,7 +1885,7 @@ setup:
             "analysis": {
               "classification": {
                 "dependent_variable": "foo",
-                "num_top_classes": -1
+                "num_top_classes": -2
               }
             }
           }
@@ -1894,7 +1894,7 @@ setup:
 "Test put classification given num_top_classes is greater than 1k":
 
   - do:
-      catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
+      catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
       ml.put_data_frame_analytics:
         id: "classification-training-percent-is-greater-than-hundred"
         body: >