|
@@ -11,6 +11,7 @@ import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
|
+import org.elasticsearch.common.xcontent.ParseField;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
@@ -21,50 +22,58 @@ import java.io.IOException;
|
|
|
import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.Objects;
|
|
|
+import java.util.Optional;
|
|
|
|
|
|
-public class SentimentAnalysisConfig implements NlpConfig {
|
|
|
+public class TextClassificationConfig implements NlpConfig {
|
|
|
|
|
|
- public static final String NAME = "sentiment_analysis";
|
|
|
+ public static final String NAME = "text_classification";
|
|
|
+ public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
|
|
|
|
|
- public static SentimentAnalysisConfig fromXContentStrict(XContentParser parser) {
|
|
|
+ public static TextClassificationConfig fromXContentStrict(XContentParser parser) {
|
|
|
return STRICT_PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- public static SentimentAnalysisConfig fromXContentLenient(XContentParser parser) {
|
|
|
+ public static TextClassificationConfig fromXContentLenient(XContentParser parser) {
|
|
|
return LENIENT_PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- private static final ConstructingObjectParser<SentimentAnalysisConfig, Void> STRICT_PARSER = createParser(false);
|
|
|
- private static final ConstructingObjectParser<SentimentAnalysisConfig, Void> LENIENT_PARSER = createParser(true);
|
|
|
+ private static final ConstructingObjectParser<TextClassificationConfig, Void> STRICT_PARSER = createParser(false);
|
|
|
+ private static final ConstructingObjectParser<TextClassificationConfig, Void> LENIENT_PARSER = createParser(true);
|
|
|
|
|
|
@SuppressWarnings({ "unchecked"})
|
|
|
- private static ConstructingObjectParser<SentimentAnalysisConfig, Void> createParser(boolean ignoreUnknownFields) {
|
|
|
- ConstructingObjectParser<SentimentAnalysisConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
|
|
- a -> new SentimentAnalysisConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
|
|
|
+ private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
|
|
|
+ ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
|
|
+ a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
|
|
|
parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
|
|
|
parser.declareNamedObject(
|
|
|
ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
|
|
TOKENIZATION
|
|
|
);
|
|
|
parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
|
|
|
+ parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
|
|
return parser;
|
|
|
}
|
|
|
|
|
|
private final VocabularyConfig vocabularyConfig;
|
|
|
private final Tokenization tokenization;
|
|
|
private final List<String> classificationLabels;
|
|
|
+ private final int numTopClasses;
|
|
|
|
|
|
- public SentimentAnalysisConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization,
|
|
|
- @Nullable List<String> classificationLabels) {
|
|
|
+ public TextClassificationConfig(VocabularyConfig vocabularyConfig,
|
|
|
+ @Nullable Tokenization tokenization,
|
|
|
+ @Nullable List<String> classificationLabels,
|
|
|
+ @Nullable Integer numTopClasses) {
|
|
|
this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
|
|
|
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
|
|
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
|
|
|
+ this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
|
|
|
}
|
|
|
|
|
|
- public SentimentAnalysisConfig(StreamInput in) throws IOException {
|
|
|
+ public TextClassificationConfig(StreamInput in) throws IOException {
|
|
|
vocabularyConfig = new VocabularyConfig(in);
|
|
|
tokenization = in.readNamedWriteable(Tokenization.class);
|
|
|
classificationLabels = in.readStringList();
|
|
|
+ numTopClasses = in.readInt();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -72,6 +81,7 @@ public class SentimentAnalysisConfig implements NlpConfig {
|
|
|
vocabularyConfig.writeTo(out);
|
|
|
out.writeNamedWriteable(tokenization);
|
|
|
out.writeStringCollection(classificationLabels);
|
|
|
+ out.writeInt(numTopClasses);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -82,6 +92,7 @@ public class SentimentAnalysisConfig implements NlpConfig {
|
|
|
if (classificationLabels.isEmpty() == false) {
|
|
|
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
|
|
}
|
|
|
+ builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
}
|
|
@@ -111,15 +122,16 @@ public class SentimentAnalysisConfig implements NlpConfig {
|
|
|
if (o == this) return true;
|
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
|
|
|
|
- SentimentAnalysisConfig that = (SentimentAnalysisConfig) o;
|
|
|
+ TextClassificationConfig that = (TextClassificationConfig) o;
|
|
|
return Objects.equals(vocabularyConfig, that.vocabularyConfig)
|
|
|
&& Objects.equals(tokenization, that.tokenization)
|
|
|
+ && Objects.equals(numTopClasses, that.numTopClasses)
|
|
|
&& Objects.equals(classificationLabels, that.classificationLabels);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(vocabularyConfig, tokenization, classificationLabels);
|
|
|
+ return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -136,6 +148,10 @@ public class SentimentAnalysisConfig implements NlpConfig {
|
|
|
return classificationLabels;
|
|
|
}
|
|
|
|
|
|
+ public int getNumTopClasses() {
|
|
|
+ return numTopClasses;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public boolean isAllocateOnly() {
|
|
|
return true;
|