Browse Source

[ML] generalize pytorch sentiment analysis to text classification (#77084)

* [ML] generalize pytorch sentiment analysis to text classification

* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java
Benjamin Trent 4 years ago
parent
commit
0e1efa6533

+ 3 - 3
docs/reference/ml/df-analytics/apis/infer-trained-model-deployment.asciidoc

@@ -59,8 +59,8 @@ The input text for evaluation.
 [[infer-trained-model-deployment-example]]
 == {api-examples-title}
 
-The response depends on the task the model is trained for. If it is a 
-sentiment analysis task, the response is the score. For example:
+The response depends on the task the model is trained for. If it is a
+text classification task, the response contains the score of each returned class. For example:
 
 [source,console]
 --------------------------------------------------
@@ -77,7 +77,7 @@ The API returns scores in this case, for example:
 ----
 {
   "positive" : 0.9998062667902223,
-  "negative" : 1.9373320977752957E-4	
+  "negative" : 1.9373320977752957E-4
 }
 ----
 // NOTCONSOLE

+ 8 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -25,7 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughConfig;
@@ -44,7 +44,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
@@ -172,9 +172,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(FillMaskConfig.NAME),
             FillMaskConfig::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
-            new ParseField(SentimentAnalysisConfig.NAME), SentimentAnalysisConfig::fromXContentLenient));
-        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(SentimentAnalysisConfig.NAME),
-            SentimentAnalysisConfig::fromXContentStrict));
+            new ParseField(TextClassificationConfig.NAME), TextClassificationConfig::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class,
+            new ParseField(TextClassificationConfig.NAME), TextClassificationConfig::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
             new ParseField(BertPassThroughConfig.NAME), BertPassThroughConfig::fromXContentLenient));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(BertPassThroughConfig.NAME),
@@ -269,8 +269,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             PyTorchPassThroughResults.NAME,
             PyTorchPassThroughResults::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
-            SentimentAnalysisResults.NAME,
-            SentimentAnalysisResults::new));
+            TextClassificationResults.NAME,
+            TextClassificationResults::new));
 
         // Inference Configs
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
@@ -282,7 +282,7 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             FillMaskConfig.NAME, FillMaskConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
-            SentimentAnalysisConfig.NAME, SentimentAnalysisConfig::new));
+            TextClassificationConfig.NAME, TextClassificationConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             BertPassThroughConfig.NAME, BertPassThroughConfig::new));
 

+ 0 - 107
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SentimentAnalysisResults.java

@@ -1,107 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.ml.inference.results;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.XContentBuilder;
-
-import java.io.IOException;
-import java.util.LinkedHashMap;
-import java.util.Map;
-import java.util.Objects;
-
-public class SentimentAnalysisResults implements InferenceResults {
-
-    public static final String NAME = "sentiment_analysis_result";
-
-    private final String class1Label;
-    private final String class2Label;
-    private final double class1Score;
-    private final double class2Score;
-
-    public SentimentAnalysisResults(String class1Label, double class1Score,
-                                    String class2Label, double class2Score) {
-        this.class1Label = class1Label;
-        this.class1Score = class1Score;
-        this.class2Label = class2Label;
-        this.class2Score = class2Score;
-    }
-
-    public SentimentAnalysisResults(StreamInput in) throws IOException {
-        class1Label = in.readString();
-        class1Score = in.readDouble();
-        class2Label = in.readString();
-        class2Score = in.readDouble();
-    }
-
-    public String getClass1Label() {
-        return class1Label;
-    }
-
-    public double getClass1Score() {
-        return class1Score;
-    }
-
-    public String getClass2Label() {
-        return class2Label;
-    }
-
-    public double getClass2Score() {
-        return class2Score;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.field(class1Label, class1Score);
-        builder.field(class2Label, class2Score);
-        return builder;
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeString(class1Label);
-        out.writeDouble(class1Score);
-        out.writeString(class2Label);
-        out.writeDouble(class2Score);
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> map = new LinkedHashMap<>();
-        map.put(class1Label, class1Score);
-        map.put(class2Label, class2Score);
-        return map;
-    }
-
-    @Override
-    public Object predictedValue() {
-        return class1Score;
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        SentimentAnalysisResults that = (SentimentAnalysisResults) o;
-        return Double.compare(that.class1Score, class1Score) == 0 &&
-            Double.compare(that.class2Score, class2Score) == 0 &&
-            Objects.equals(this.class1Label, that.class1Label) &&
-            Objects.equals(this.class2Label, that.class2Label);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(class1Label, class1Score, class2Label, class2Score);
-    }
-}

+ 76 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextClassificationResults.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+public class TextClassificationResults implements InferenceResults {
+
+    public static final String NAME = "text_classification_result";
+
+    private final List<TopClassEntry> entryList;
+
+    public TextClassificationResults(List<TopClassEntry> entryList) {
+        this.entryList = entryList;
+    }
+
+    public TextClassificationResults(StreamInput in) throws IOException {
+        entryList = in.readList(TopClassEntry::new);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.mapContents(asMap());
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeList(entryList);
+    }
+
+    @Override
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        for (TopClassEntry entry : entryList) {
+            map.put(entry.getClassification().toString(), entry.getScore());
+        }
+        return map;
+    }
+
+    @Override
+    public Object predictedValue() {
+        return entryList.get(0).getScore();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TextClassificationResults that = (TextClassificationResults) o;
+        return Objects.equals(that.entryList, entryList);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(entryList);
+    }
+}

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TopClassEntry.java

@@ -66,6 +66,10 @@ public class TopClassEntry implements Writeable, ToXContentObject {
     private final double probability;
     private final double score;
 
+    public TopClassEntry(Object classification, double score) {
+        this(classification, score, score);
+    }
+
     public TopClassEntry(Object classification, double probability, double score) {
         this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
         this.probability = probability;

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

@@ -6,6 +6,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -23,6 +24,10 @@ public class WarningInferenceResults implements InferenceResults {
 
     private final String warning;
 
+    public WarningInferenceResults(String warning, Object... args) {
+        this(LoggerMessageFormat.format(warning, args));
+    }
+
     public WarningInferenceResults(String warning) {
         this.warning = warning;
     }

+ 30 - 14
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfig.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java

@@ -11,6 +11,7 @@ import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
@@ -21,50 +22,58 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
-public class SentimentAnalysisConfig implements NlpConfig {
+public class TextClassificationConfig implements NlpConfig {
 
-    public static final String NAME = "sentiment_analysis";
+    public static final String NAME = "text_classification";
+    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
 
-    public static SentimentAnalysisConfig fromXContentStrict(XContentParser parser) {
+    public static TextClassificationConfig fromXContentStrict(XContentParser parser) {
         return STRICT_PARSER.apply(parser, null);
     }
 
-    public static SentimentAnalysisConfig fromXContentLenient(XContentParser parser) {
+    public static TextClassificationConfig fromXContentLenient(XContentParser parser) {
         return LENIENT_PARSER.apply(parser, null);
     }
 
-    private static final ConstructingObjectParser<SentimentAnalysisConfig, Void> STRICT_PARSER = createParser(false);
-    private static final ConstructingObjectParser<SentimentAnalysisConfig, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<TextClassificationConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<TextClassificationConfig, Void> LENIENT_PARSER = createParser(true);
 
     @SuppressWarnings({ "unchecked"})
-    private static ConstructingObjectParser<SentimentAnalysisConfig, Void> createParser(boolean ignoreUnknownFields) {
-        ConstructingObjectParser<SentimentAnalysisConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new SentimentAnalysisConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
+    private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
+            a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
         parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
         parser.declareNamedObject(
             ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
                 TOKENIZATION
         );
         parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
+        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
     private final Tokenization tokenization;
     private final List<String> classificationLabels;
+    private final int numTopClasses;
 
-    public SentimentAnalysisConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization,
-                                   @Nullable List<String> classificationLabels) {
+    public TextClassificationConfig(VocabularyConfig vocabularyConfig,
+                                    @Nullable Tokenization tokenization,
+                                    @Nullable List<String> classificationLabels,
+                                    @Nullable Integer numTopClasses) {
         this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
         this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
+        this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
     }
 
-    public SentimentAnalysisConfig(StreamInput in) throws IOException {
+    public TextClassificationConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
         tokenization = in.readNamedWriteable(Tokenization.class);
         classificationLabels = in.readStringList();
+        numTopClasses = in.readInt();
     }
 
     @Override
@@ -72,6 +81,7 @@ public class SentimentAnalysisConfig implements NlpConfig {
         vocabularyConfig.writeTo(out);
         out.writeNamedWriteable(tokenization);
         out.writeStringCollection(classificationLabels);
+        out.writeInt(numTopClasses);
     }
 
     @Override
@@ -82,6 +92,7 @@ public class SentimentAnalysisConfig implements NlpConfig {
         if (classificationLabels.isEmpty() == false) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
+        builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         builder.endObject();
         return builder;
     }
@@ -111,15 +122,16 @@ public class SentimentAnalysisConfig implements NlpConfig {
         if (o == this) return true;
         if (o == null || getClass() != o.getClass()) return false;
 
-        SentimentAnalysisConfig that = (SentimentAnalysisConfig) o;
+        TextClassificationConfig that = (TextClassificationConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
             && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(numTopClasses, that.numTopClasses)
             && Objects.equals(classificationLabels, that.classificationLabels);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization, classificationLabels);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses);
     }
 
     @Override
@@ -136,6 +148,10 @@ public class SentimentAnalysisConfig implements NlpConfig {
         return classificationLabels;
     }
 
+    public int getNumTopClasses() {
+        return numTopClasses;
+    }
+
     @Override
     public boolean isAllocateOnly() {
         return true;

+ 16 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/SentimentAnalysisResultsTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextClassificationResultsTests.java

@@ -10,25 +10,34 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.hasSize;
 
-public class SentimentAnalysisResultsTests extends AbstractWireSerializingTestCase<SentimentAnalysisResults> {
+public class TextClassificationResultsTests extends AbstractWireSerializingTestCase<TextClassificationResults> {
     @Override
-    protected Writeable.Reader<SentimentAnalysisResults> instanceReader() {
-        return SentimentAnalysisResults::new;
+    protected Writeable.Reader<TextClassificationResults> instanceReader() {
+        return TextClassificationResults::new;
     }
 
     @Override
-    protected SentimentAnalysisResults createTestInstance() {
-        return new SentimentAnalysisResults(randomAlphaOfLength(6), randomDouble(),
-            randomAlphaOfLength(6), randomDouble());
+    protected TextClassificationResults createTestInstance() {
+        return new TextClassificationResults(
+            Stream.generate(TopClassEntryTests::createRandomTopClassEntry).limit(randomIntBetween(2, 5)).collect(Collectors.toList())
+        );
     }
 
     public void testAsMap() {
-        SentimentAnalysisResults testInstance = new SentimentAnalysisResults("foo", 1.0, "bar", 0.0);
+        TextClassificationResults testInstance = new TextClassificationResults(
+            List.of(
+                new TopClassEntry("foo", 1.0),
+                new TopClassEntry("bar", 0.0)
+            )
+        );
         Map<String, Object> asMap = testInstance.asMap();
         assertThat(asMap.keySet(), hasSize(2));
         assertThat(1.0, closeTo((Double)asMap.get("foo"), 0.0001));

+ 11 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfigTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java

@@ -15,7 +15,7 @@ import org.junit.Before;
 
 import java.io.IOException;
 
-public class SentimentAnalysisConfigTests extends InferenceConfigItemTestCase<SentimentAnalysisConfig> {
+public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {
 
     private boolean lenient;
 
@@ -25,32 +25,33 @@ public class SentimentAnalysisConfigTests extends InferenceConfigItemTestCase<Se
     }
 
     @Override
-    protected SentimentAnalysisConfig doParseInstance(XContentParser parser) throws IOException {
-        return lenient ? SentimentAnalysisConfig.fromXContentLenient(parser) : SentimentAnalysisConfig.fromXContentStrict(parser);
+    protected TextClassificationConfig doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? TextClassificationConfig.fromXContentLenient(parser) : TextClassificationConfig.fromXContentStrict(parser);
     }
 
     @Override
-    protected Writeable.Reader<SentimentAnalysisConfig> instanceReader() {
-        return SentimentAnalysisConfig::new;
+    protected Writeable.Reader<TextClassificationConfig> instanceReader() {
+        return TextClassificationConfig::new;
     }
 
     @Override
-    protected SentimentAnalysisConfig createTestInstance() {
+    protected TextClassificationConfig createTestInstance() {
         return createRandom();
     }
 
     @Override
-    protected SentimentAnalysisConfig mutateInstanceForVersion(SentimentAnalysisConfig instance, Version version) {
+    protected TextClassificationConfig mutateInstanceForVersion(TextClassificationConfig instance, Version version) {
         return instance;
     }
 
-    public static SentimentAnalysisConfig createRandom() {
-        return new SentimentAnalysisConfig(
+    public static TextClassificationConfig createRandom() {
+        return new TextClassificationConfig(
             VocabularyConfigTests.createRandom(),
             randomBoolean() ?
                 null :
                 randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
-            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10))
+            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomIntBetween(-1, 10)
         );
     }
 }

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -11,7 +11,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughCon
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 import java.util.Locale;
@@ -24,10 +24,10 @@ public enum TaskType {
             return new NerProcessor(tokenizer, (NerConfig) config);
         }
     },
-    SENTIMENT_ANALYSIS {
+    TEXT_CLASSIFICATION {
         @Override
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
-            return new SentimentAnalysisProcessor(tokenizer, (SentimentAnalysisConfig) config);
+            return new TextClassificationProcessor(tokenizer, (TextClassificationConfig) config);
         }
     },
     FILL_MASK {

+ 48 - 18
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessor.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -7,45 +7,67 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
-public class SentimentAnalysisProcessor implements NlpTask.Processor {
+public class TextClassificationProcessor implements NlpTask.Processor {
 
     private final NlpTokenizer tokenizer;
-    private final List<String> classLabels;
+    private final String[] classLabels;
+    private final int numTopClasses;
 
-    SentimentAnalysisProcessor(NlpTokenizer tokenizer, SentimentAnalysisConfig config) {
+    TextClassificationProcessor(NlpTokenizer tokenizer, TextClassificationConfig config) {
         this.tokenizer = tokenizer;
         List<String> classLabels = config.getClassificationLabels();
         if (classLabels == null || classLabels.isEmpty()) {
-            this.classLabels = List.of("negative", "positive");
+            this.classLabels = new String[] {"negative", "positive"};
         } else {
-            this.classLabels = classLabels;
+            this.classLabels = classLabels.toArray(String[]::new);
         }
-
+        // A negative value is a special case that requests ALL classes. Because the model output size is required to
+        // equal the number of class labels, the label count is used as the effective number of top classes.
+        this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
         validate();
     }
 
     private void validate() {
-        if (classLabels.size() != 2) {
+        if (classLabels.length < 2) {
             throw new ValidationException().addValidationError(
-                String.format(Locale.ROOT, "Sentiment analysis requires exactly 2 [%s]. Invalid labels %s",
-                    SentimentAnalysisConfig.CLASSIFICATION_LABELS, classLabels)
+                String.format(
+                    Locale.ROOT,
+                    "Text classification requires at least 2 [%s]. Invalid labels [%s]",
+                    TextClassificationConfig.CLASSIFICATION_LABELS,
+                    Strings.arrayToCommaDelimitedString(classLabels)
+                )
+            );
+        }
+        if (numTopClasses == 0) {
+            throw new ValidationException().addValidationError(
+                String.format(
+                    Locale.ROOT,
+                    "Text classification requires at least 1 [%s]; provided [%d]",
+                    TextClassificationConfig.NUM_TOP_CLASSES,
+                    numTopClasses
+                )
             );
         }
     }
@@ -72,18 +94,26 @@ public class SentimentAnalysisProcessor implements NlpTask.Processor {
 
     InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         if (pyTorchResult.getInferenceResult().length < 1) {
-            return new WarningInferenceResults("Sentiment analysis result has no data");
+            return new WarningInferenceResults("Text classification result has no data");
         }
 
-        if (pyTorchResult.getInferenceResult()[0].length < 2) {
-            return new WarningInferenceResults("Expected 2 values in sentiment analysis result");
+        if (pyTorchResult.getInferenceResult()[0].length != classLabels.length) {
+            return new WarningInferenceResults(
+                "Expected exactly [{}] values in text classification result; got [{}]",
+                classLabels.length,
+                pyTorchResult.getInferenceResult()[0].length
+            );
         }
 
         double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
-        // the second score is usually the positive score so put that first
-        // so it comes first in the results doc
-        return new SentimentAnalysisResults(classLabels.get(1), normalizedScores[1],
-            classLabels.get(0), normalizedScores[0]);
+        return new TextClassificationResults(
+            IntStream.range(0, normalizedScores.length)
+                .mapToObj(i -> new TopClassEntry(classLabels[i], normalizedScores[i]))
+                // Put the highest scoring class first
+                .sorted(Comparator.comparing(TopClassEntry::getProbability).reversed())
+                .limit(numTopClasses)
+                .collect(Collectors.toList())
+        );
     }
 
     static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {

+ 36 - 19
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessorTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -14,7 +14,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
@@ -29,24 +29,25 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.mock;
 
-public class SentimentAnalysisProcessorTests extends ESTestCase {
+public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
-        SentimentAnalysisConfig config = new SentimentAnalysisConfig(new VocabularyConfig("test-index", "vocab"), null, null);
-        SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(mock(BertTokenizer.class), config);
+        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
+        TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{}, 0L, null);
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][] {}, 0L, null);
             InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-            assertEquals("Sentiment analysis result has no data",
-                ((WarningInferenceResults) inferenceResults).getWarning());
+            assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
         }
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{{1.0}}, 0L, null);
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][] { { 1.0 } }, 0L, null);
             InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-            assertEquals("Expected 2 values in sentiment analysis result",
-                ((WarningInferenceResults)inferenceResults).getWarning());
+            assertEquals(
+                "Expected exactly [2] values in text classification result; got [1]",
+                ((WarningInferenceResults) inferenceResults).getWarning()
+            );
         }
     }
 
@@ -56,8 +57,8 @@ public class SentimentAnalysisProcessorTests extends ESTestCase {
             new BertTokenization(null, null, 512)
         ).build();
 
-        SentimentAnalysisConfig config = new SentimentAnalysisConfig(new VocabularyConfig("test-index", "vocab"), null, null);
-        SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(tokenizer, config);
+        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
+        TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
         NlpTask.Request request = processor.buildRequest("Elasticsearch fun", "request1");
 
@@ -70,14 +71,30 @@ public class SentimentAnalysisProcessorTests extends ESTestCase {
     }
 
     public void testValidate() {
+        ValidationException validationException = expectThrows(
+            ValidationException.class,
+            () -> new TextClassificationProcessor(
+                mock(BertTokenizer.class),
+                new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("too few"), null)
+            )
+        );
 
-        SentimentAnalysisConfig config = new SentimentAnalysisConfig(new VocabularyConfig("test-index", "vocab"), null,
-            List.of("too", "many", "class", "labels"));
+        assertThat(
+            validationException.getMessage(),
+            containsString("Text classification requires at least 2 [classification_labels]. Invalid labels [too few]")
+        );
 
-        ValidationException validationException = expectThrows(ValidationException.class,
-            () -> new SentimentAnalysisProcessor(mock(BertTokenizer.class), config));
+        validationException = expectThrows(
+            ValidationException.class,
+            () -> new TextClassificationProcessor(
+                mock(BertTokenizer.class),
+                new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, List.of("class", "labels"), 0)
+            )
+        );
 
-        assertThat(validationException.getMessage(),
-            containsString("Sentiment analysis requires exactly 2 [classification_labels]. Invalid labels [too, many, class, labels]"));
+        assertThat(
+            validationException.getMessage(),
+            containsString("Text classification requires at least 1 [num_top_classes]; provided [0]")
+        );
     }
- }
+}