Browse Source

[ML] Adds the class_assignment_objective parameter to classification (#52763)

Adds a new parameter for classification that enables choosing whether to assign labels to
maximise accuracy or to maximise the minimum class recall.

Fixes #52427.
Tom Veasey 5 years ago
parent
commit
58340c2dbe
17 changed files with 250 additions and 32 deletions
  1. 46 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java
  2. 2 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  3. 4 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  4. 1 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java
  5. 2 1
      docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc
  6. 4 0
      docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc
  7. 8 0
      docs/reference/ml/ml-shared.asciidoc
  8. 48 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java
  9. 1 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java
  10. 3 0
      x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json
  11. 2 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java
  12. 38 16
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java
  13. 4 3
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  14. 3 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml
  15. 25 0
      x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml
  16. 33 0
      x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml
  17. 26 0
      x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml

+ 46 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

@@ -22,10 +22,12 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.Locale;
 import java.util.Objects;
 
 public class Classification implements DataFrameAnalysis {
@@ -49,6 +51,7 @@ public class Classification implements DataFrameAnalysis {
     static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
     static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
     static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
+    static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
     static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
 
@@ -67,7 +70,8 @@ public class Classification implements DataFrameAnalysis {
                 (String) a[7],
                 (Double) a[8],
                 (Integer) a[9],
-                (Long) a[10]));
+                (Long) a[10],
+                (ClassAssignmentObjective) a[11]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -81,6 +85,12 @@ public class Classification implements DataFrameAnalysis {
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
         PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
         PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
+        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return ClassAssignmentObjective.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
     }
 
     private final String dependentVariable;
@@ -92,13 +102,15 @@ public class Classification implements DataFrameAnalysis {
     private final Integer numTopFeatureImportanceValues;
     private final String predictionFieldName;
     private final Double trainingPercent;
+    private final ClassAssignmentObjective classAssignmentObjective;
     private final Integer numTopClasses;
     private final Long randomizeSeed;
 
     private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                            @Nullable Integer maxTrees, @Nullable Double featureBagFraction,
                            @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
-                           @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
+                           @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
+                           @Nullable ClassAssignmentObjective classAssignmentObjective) {
         this.dependentVariable = Objects.requireNonNull(dependentVariable);
         this.lambda = lambda;
         this.gamma = gamma;
@@ -108,6 +120,7 @@ public class Classification implements DataFrameAnalysis {
         this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
         this.predictionFieldName = predictionFieldName;
         this.trainingPercent = trainingPercent;
+        this.classAssignmentObjective = classAssignmentObjective;
         this.numTopClasses = numTopClasses;
         this.randomizeSeed = randomizeSeed;
     }
@@ -157,6 +170,10 @@ public class Classification implements DataFrameAnalysis {
         return randomizeSeed;
     }
 
+    public ClassAssignmentObjective getClassAssignmentObjective() {
+        return classAssignmentObjective;
+    }
+
     public Integer getNumTopClasses() {
         return numTopClasses;
     }
@@ -192,6 +209,9 @@ public class Classification implements DataFrameAnalysis {
         if (randomizeSeed != null) {
             builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
         }
+        if (classAssignmentObjective != null) {
+            builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
+        }
         if (numTopClasses != null) {
             builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         }
@@ -202,7 +222,7 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
-            predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
+            predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
     }
 
     @Override
@@ -220,7 +240,8 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(predictionFieldName, that.predictionFieldName)
             && Objects.equals(trainingPercent, that.trainingPercent)
             && Objects.equals(randomizeSeed, that.randomizeSeed)
-            && Objects.equals(numTopClasses, that.numTopClasses);
+            && Objects.equals(numTopClasses, that.numTopClasses)
+            && Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
     }
 
     @Override
@@ -228,6 +249,19 @@ public class Classification implements DataFrameAnalysis {
         return Strings.toString(this);
     }
 
+    public enum ClassAssignmentObjective {
+        MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
+
+        public static ClassAssignmentObjective fromString(String value) {
+            return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
+
     public static class Builder {
         private String dependentVariable;
         private Double lambda;
@@ -240,6 +274,7 @@ public class Classification implements DataFrameAnalysis {
         private Double trainingPercent;
         private Integer numTopClasses;
         private Long randomizeSeed;
+        private ClassAssignmentObjective classAssignmentObjective;
 
         private Builder(String dependentVariable) {
             this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -295,9 +330,15 @@ public class Classification implements DataFrameAnalysis {
             return this;
         }
 
+        public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) {
+            this.classAssignmentObjective = classAssignmentObjective;
+            return this;
+        }
+
         public Classification build() {
             return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
-                numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
+                numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
+                classAssignmentObjective);
         }
     }
 }

+ 2 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1336,6 +1336,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .setPredictionFieldName("my_dependent_variable_prediction")
                 .setTrainingPercent(80.0)
                 .setRandomizeSeed(42L)
+                .setClassAssignmentObjective(
+                    org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)
                 .setNumTopClasses(1)
                 .setLambda(1.0)
                 .setGamma(1.0)

+ 4 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.client.ml.datafeed.DatafeedStats;
 import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
 import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
+import org.elasticsearch.client.ml.dataframe.Classification;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
@@ -2969,7 +2970,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             // end::put-data-frame-analytics-outlier-detection-customized
 
             // tag::put-data-frame-analytics-classification
-            DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
+            DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1>
                 .setLambda(1.0) // <2>
                 .setGamma(5.5) // <3>
                 .setEta(5.5) // <4>
@@ -2979,7 +2980,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setPredictionFieldName("my_prediction_field_name") // <8>
                 .setTrainingPercent(50.0) // <9>
                 .setRandomizeSeed(1234L) // <10>
-                .setNumTopClasses(1) // <11>
+                .setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
+                .setNumTopClasses(1) // <12>
                 .build();
             // end::put-data-frame-analytics-classification
 

+ 1 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

@@ -36,6 +36,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
             .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
             .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
             .setRandomizeSeed(randomBoolean() ? null : randomLong())
+            .setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
             .setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
             .build();
     }

+ 2 - 1
docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

@@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
 <8> The name of the prediction field in the results object.
 <9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
 <10> The seed to be used by the random generator that picks which rows are used in training.
-<11> The number of top classes to be reported in the results. Defaults to 2.
+<11> The optimization objective to target when assigning class labels. Defaults to `maximize_minimum_recall`.
+<12> The number of top classes to be reported in the results. Defaults to 2.
 
 ===== Regression
 

+ 4 - 0
docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

@@ -136,6 +136,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
 (Optional, double) 
 include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
 
+`analysis`.`classification`.`class_assignment_objective`::::
+(Optional, string)
+include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]
+
 `analysis`.`classification`.`num_top_classes`::::
 (Optional, integer)
 include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]

+ 8 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -339,6 +339,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode]
 include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
 end::chunking-config[]
 
+tag::class-assignment-objective[]
+Defines the objective to optimize when assigning class labels. Available
+objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing
+accuracy, class labels are chosen to maximize the number of correct predictions.
+When maximizing minimum recall, labels are chosen to maximize the minimum recall
+for any class. Defaults to `maximize_minimum_recall`.
+end::class-assignment-objective[]
+
 tag::custom-rules[]
 An array of custom rule objects, which enable you to customize the way detectors
 operate. For example, a rule may dictate to the detector conditions under which

+ 48 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.FieldAliasMapper;
@@ -21,6 +22,7 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
@@ -37,6 +39,7 @@ public class Classification implements DataFrameAnalysis {
 
     public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
     public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
+    public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
@@ -54,12 +57,19 @@ public class Classification implements DataFrameAnalysis {
                 (String) a[0],
                 new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
                 (String) a[7],
-                (Integer) a[8],
-                (Double) a[9],
-                (Long) a[10]));
+                (ClassAssignmentObjective) a[8],
+                (Integer) a[9],
+                (Double) a[10],
+                (Long) a[11]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
+        parser.declareField(optionalConstructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return ClassAssignmentObjective.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
         parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
         parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
         parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
@@ -89,6 +99,7 @@ public class Classification implements DataFrameAnalysis {
     private final String dependentVariable;
     private final BoostedTreeParams boostedTreeParams;
     private final String predictionFieldName;
+    private final ClassAssignmentObjective classAssignmentObjective;
     private final int numTopClasses;
     private final double trainingPercent;
     private final long randomizeSeed;
@@ -96,6 +107,7 @@ public class Classification implements DataFrameAnalysis {
     public Classification(String dependentVariable,
                           BoostedTreeParams boostedTreeParams,
                           @Nullable String predictionFieldName,
+                          @Nullable ClassAssignmentObjective classAssignmentObjective,
                           @Nullable Integer numTopClasses,
                           @Nullable Double trainingPercent,
                           @Nullable Long randomizeSeed) {
@@ -108,19 +120,26 @@ public class Classification implements DataFrameAnalysis {
         this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
         this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
         this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
+        this.classAssignmentObjective = classAssignmentObjective == null ?
+            ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL : classAssignmentObjective;
         this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
         this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
         this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
     }
 
     public Classification(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
     }
 
     public Classification(StreamInput in) throws IOException {
         dependentVariable = in.readString();
         boostedTreeParams = new BoostedTreeParams(in);
         predictionFieldName = in.readOptionalString();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class);
+        } else {
+            classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL;
+        }
         numTopClasses = in.readOptionalVInt();
         trainingPercent = in.readDouble();
         if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
@@ -142,6 +161,10 @@ public class Classification implements DataFrameAnalysis {
         return predictionFieldName;
     }
 
+    public ClassAssignmentObjective getClassAssignmentObjective() {
+        return classAssignmentObjective;
+    }
+
     public int getNumTopClasses() {
         return numTopClasses;
     }
@@ -164,6 +187,9 @@ public class Classification implements DataFrameAnalysis {
         out.writeString(dependentVariable);
         boostedTreeParams.writeTo(out);
         out.writeOptionalString(predictionFieldName);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeEnum(classAssignmentObjective);
+        }
         out.writeOptionalVInt(numTopClasses);
         out.writeDouble(trainingPercent);
         if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
@@ -178,6 +204,7 @@ public class Classification implements DataFrameAnalysis {
         builder.startObject();
         builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
         boostedTreeParams.toXContent(builder, params);
+        builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
         builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         if (predictionFieldName != null) {
             builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
@@ -195,6 +222,7 @@ public class Classification implements DataFrameAnalysis {
         Map<String, Object> params = new HashMap<>();
         params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
         params.putAll(boostedTreeParams.getParams());
+        params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
         params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         if (predictionFieldName != null) {
             params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
@@ -303,6 +331,7 @@ public class Classification implements DataFrameAnalysis {
         return Objects.equals(dependentVariable, that.dependentVariable)
             && Objects.equals(boostedTreeParams, that.boostedTreeParams)
             && Objects.equals(predictionFieldName, that.predictionFieldName)
+            && Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
             && Objects.equals(numTopClasses, that.numTopClasses)
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed;
@@ -310,6 +339,20 @@ public class Classification implements DataFrameAnalysis {
 
     @Override
     public int hashCode() {
-        return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed);
+        return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
+                            numTopClasses, trainingPercent, randomizeSeed);
+    }
+
+    public enum ClassAssignmentObjective {
+        MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
+
+        public static ClassAssignmentObjective fromString(String value) {
+            return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
     }
 }

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java

@@ -317,6 +317,7 @@ public final class ReservedFieldNames {
             Classification.NAME.getPreferredName(),
             Classification.DEPENDENT_VARIABLE.getPreferredName(),
             Classification.PREDICTION_FIELD_NAME.getPreferredName(),
+            Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(),
             Classification.NUM_TOP_CLASSES.getPreferredName(),
             Classification.TRAINING_PERCENT.getPreferredName(),
             BoostedTreeParams.LAMBDA.getPreferredName(),

+ 3 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json

@@ -43,6 +43,9 @@
               "max_trees" : {
                 "type" : "integer"
               },
+              "class_assignment_objective" : {
+                "type" : "keyword"
+              },
               "num_top_classes" : {
                 "type" : "integer"
               },

+ 2 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java

@@ -155,12 +155,14 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
             bwcAnalysis = new Classification(bwcClassification.getDependentVariable(),
                 bwcClassification.getBoostedTreeParams(),
                 bwcClassification.getPredictionFieldName(),
+                bwcClassification.getClassAssignmentObjective(),
                 bwcClassification.getNumTopClasses(),
                 bwcClassification.getTrainingPercent(),
                 42L);
             testAnalysis = new Classification(testClassification.getDependentVariable(),
                 testClassification.getBoostedTreeParams(),
                 testClassification.getPredictionFieldName(),
+                testClassification.getClassAssignmentObjective(),
                 testClassification.getNumTopClasses(),
                 testClassification.getTrainingPercent(),
                 42L);

+ 38 - 16
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -54,17 +54,20 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         String dependentVariableName = randomAlphaOfLength(10);
         BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
         String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
+        Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
+            null : randomFrom(Classification.ClassAssignmentObjective.values());
         Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
         Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
-        return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
-            randomizeSeed);
+        return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
+            numTopClasses, trainingPercent, randomizeSeed);
     }
 
     public static Classification mutateForVersion(Classification instance, Version version) {
         return new Classification(instance.getDependentVariable(),
             BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
             instance.getPredictionFieldName(),
+            version.onOrAfter(Version.V_8_0_0) ? instance.getClassAssignmentObjective() : null,
             instance.getNumTopClasses(),
             instance.getTrainingPercent(),
             instance.getRandomizeSeed());
@@ -80,12 +83,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
             bwcSerializedObject.getBoostedTreeParams(),
             bwcSerializedObject.getPredictionFieldName(),
+            bwcSerializedObject.getClassAssignmentObjective(),
             bwcSerializedObject.getNumTopClasses(),
             bwcSerializedObject.getTrainingPercent(),
             42L);
         Classification newInstance = new Classification(testInstance.getDependentVariable(),
             testInstance.getBoostedTreeParams(),
             testInstance.getPredictionFieldName(),
+            testInstance.getClassAssignmentObjective(),
             testInstance.getNumTopClasses(),
             testInstance.getTrainingPercent(),
             42L);
@@ -99,71 +104,85 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
 
     public void testConstructor_GivenTrainingPercentIsLessThanOne() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
     public void testConstructor_GivenNumTopClassesIsLessThanZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
     }
 
     public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
     }
 
     public void testGetPredictionFieldName() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
         assertThat(classification.getPredictionFieldName(), equalTo("result"));
 
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
         assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
     }
 
+    public void testClassAssignmentObjective() {
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
+            Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
+        assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
+
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
+        Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
+        assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
+
+        // class_assignment_objective == null, default applied
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
+        assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
+    }
+
     public void testGetNumTopClasses() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
         assertThat(classification.getNumTopClasses(), equalTo(7));
 
         // Boundary condition: num_top_classes == 0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
         assertThat(classification.getNumTopClasses(), equalTo(0));
 
         // Boundary condition: num_top_classes == 1000
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
         assertThat(classification.getNumTopClasses(), equalTo(1000));
 
         // num_top_classes == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
         assertThat(classification.getNumTopClasses(), equalTo(2));
     }
 
     public void testGetTrainingPercent() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
         assertThat(classification.getTrainingPercent(), equalTo(50.0));
 
         // Boundary condition: training_percent == 1.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
         assertThat(classification.getTrainingPercent(), equalTo(1.0));
 
         // Boundary condition: training_percent == 100.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
 
         // training_percent == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
     }
 
@@ -178,6 +197,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             equalTo(
                 Map.of(
                     "dependent_variable", "foo",
+                    "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
                     "num_top_classes", 2,
                     "prediction_field_name", "foo_prediction",
                     "prediction_field_type", "bool")));
@@ -186,6 +206,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             equalTo(
                 Map.of(
                     "dependent_variable", "bar",
+                    "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
                     "num_top_classes", 2,
                     "prediction_field_name", "bar_prediction",
                     "prediction_field_type", "int")));
@@ -194,6 +215,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             equalTo(
                 Map.of(
                     "dependent_variable", "baz",
+                    "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
                     "num_top_classes", 2,
                     "prediction_field_name", "baz_prediction",
                     "prediction_field_type", "string")));

+ 4 - 3
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -88,6 +88,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 null,
                 null,
                 null,
+                null,
                 null));
         registerAnalytics(config);
         putAnalytics(config);
@@ -189,7 +190,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null));
+                new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -438,7 +439,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             .build();
 
         DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
-            new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
+            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null));
         registerAnalytics(firstJob);
         putAnalytics(firstJob);
 
@@ -447,7 +448,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
         long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
         DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
-            new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed));
+            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed));
 
         registerAnalytics(secondJob);
         putAnalytics(secondJob);

+ 3 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -1834,6 +1834,7 @@ setup:
                 "eta": 0.5,
                 "max_trees": 400,
                 "feature_bag_fraction": 0.3,
+                "class_assignment_objective": "maximize_accuracy",
                 "training_percent": 60.3,
                 "randomize_seed": 24
               }
@@ -1853,6 +1854,7 @@ setup:
       "prediction_field_name": "foo_prediction",
       "training_percent": 60.3,
       "randomize_seed": 24,
+      "class_assignment_objective": "maximize_accuracy",
       "num_top_classes": 2
     }
   }}
@@ -1896,6 +1898,7 @@ setup:
       "prediction_field_name": "foo_prediction",
       "training_percent": 100.0,
       "randomize_seed": 24,
+      "class_assignment_objective": "maximize_minimum_recall",
       "num_top_classes": 2
     }
   }}

+ 25 - 0
x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml

@@ -92,6 +92,31 @@
   - match: { data_frame_analytics.0.id: "old_cluster_regression_job" }
   - match: { data_frame_analytics.0.state: "stopped" }
 
+---
+"Get old classification job":
+
+  - do:
+      ml.get_data_frame_analytics:
+        id: "old_cluster_classification_job"
+  - match: { count: 1 }
+  - match: { data_frame_analytics.0.id: "old_cluster_classification_job" }
+  - match: { data_frame_analytics.0.source.index: ["bwc_ml_classification_job_source"] }
+  - match: { data_frame_analytics.0.source.query: {"term": { "user.keyword": "Kimchy" }} }
+  - match: { data_frame_analytics.0.dest.index: "old_cluster_classification_job_results" }
+  - match: { data_frame_analytics.0.analysis.classification.dependent_variable: "foo" }
+  - match: { data_frame_analytics.0.analysis.classification.training_percent: 100.0 }
+  - is_true: data_frame_analytics.0.analysis.classification.randomize_seed
+
+---
+"Get old classification job stats":
+
+  - do:
+      ml.get_data_frame_analytics_stats:
+        id: "old_cluster_classification_job"
+  - match: { count: 1 }
+  - match: { data_frame_analytics.0.id: "old_cluster_classification_job" }
+  - match: { data_frame_analytics.0.state: "stopped" }
+
 ---
 "Put an outlier_detection job on the mixed cluster":
 

+ 33 - 0
x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml

@@ -19,6 +19,16 @@ setup:
             "user": "Kimchy"
           }
 
+  - do:
+      index:
+        index: bwc_ml_classification_job_source
+        body: >
+          {
+            "numeric_field_1": 1.0,
+            "foo": "a",
+            "user": "Kimchy"
+          }
+
   - do:
       indices.refresh:
         index: bwc_ml_*
@@ -64,3 +74,26 @@ setup:
             }
           }
   - match: { id: "old_cluster_regression_job" }
+
+---
+"Put classification job on the old cluster":
+
+  - do:
+      ml.put_data_frame_analytics:
+        id: "old_cluster_classification_job"
+        body: >
+          {
+            "source": {
+              "index": "bwc_ml_classification_job_source",
+              "query": {"term" : { "user.keyword" : "Kimchy" }}
+            },
+            "dest": {
+              "index": "old_cluster_classification_job_results"
+            },
+            "analysis": {
+              "classification":{
+                "dependent_variable": "foo"
+              }
+            }
+          }
+  - match: { id: "old_cluster_classification_job" }

+ 26 - 0
x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml

@@ -52,6 +52,32 @@
   - match: { data_frame_analytics.0.id: "old_cluster_regression_job" }
   - match: { data_frame_analytics.0.state: "stopped" }
 
+---
+"Get old classification job":
+
+  - do:
+      ml.get_data_frame_analytics:
+        id: "old_cluster_classification_job"
+  - match: { count: 1 }
+  - match: { data_frame_analytics.0.id: "old_cluster_classification_job" }
+  - match: { data_frame_analytics.0.source.index: ["bwc_ml_classification_job_source"] }
+  - match: { data_frame_analytics.0.source.query: {"term": { "user.keyword": "Kimchy" }} }
+  - match: { data_frame_analytics.0.dest.index: "old_cluster_classification_job_results" }
+  - match: { data_frame_analytics.0.analysis.classification.dependent_variable: "foo" }
+  - match: { data_frame_analytics.0.analysis.classification.training_percent: 100.0 }
+  - match: { data_frame_analytics.0.analysis.classification.class_assignment_objective: "maximize_minimum_recall" }
+  - is_true: data_frame_analytics.0.analysis.classification.randomize_seed
+
+---
+"Get old classification job stats":
+
+  - do:
+      ml.get_data_frame_analytics_stats:
+        id: "old_cluster_classification_job"
+  - match: { count: 1 }
+  - match: { data_frame_analytics.0.id: "old_cluster_classification_job" }
+  - match: { data_frame_analytics.0.state: "stopped" }
+
 ---
 "Get mixed cluster outlier_detection job":
   - skip: