|
@@ -10,8 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|
|
import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
-import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
|
-import org.elasticsearch.common.xcontent.ParseField;
|
|
|
+import org.elasticsearch.common.xcontent.ObjectParser;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
@@ -20,7 +19,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.Objects;
|
|
|
import java.util.Optional;
|
|
@@ -28,25 +26,24 @@ import java.util.Optional;
|
|
|
public class TextClassificationConfig implements NlpConfig {
|
|
|
|
|
|
public static final String NAME = "text_classification";
|
|
|
- public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
|
|
|
|
|
public static TextClassificationConfig fromXContentStrict(XContentParser parser) {
|
|
|
- return STRICT_PARSER.apply(parser, null);
|
|
|
+ return STRICT_PARSER.apply(parser, null).build(); // STRICT_PARSER (ignoreUnknownFields == false) now yields a Builder; build() invokes the validating constructor
|
|
|
}
|
|
|
|
|
|
public static TextClassificationConfig fromXContentLenient(XContentParser parser) {
|
|
|
- return LENIENT_PARSER.apply(parser, null);
|
|
|
+ return LENIENT_PARSER.apply(parser, null).build(); // LENIENT_PARSER (ignoreUnknownFields == true) now yields a Builder; build() invokes the validating constructor
|
|
|
}
|
|
|
|
|
|
- private static final ConstructingObjectParser<TextClassificationConfig, Void> STRICT_PARSER = createParser(false);
|
|
|
- private static final ConstructingObjectParser<TextClassificationConfig, Void> LENIENT_PARSER = createParser(true);
|
|
|
+ private static final ObjectParser<TextClassificationConfig.Builder, Void> STRICT_PARSER = createParser(false);
|
|
|
+ private static final ObjectParser<TextClassificationConfig.Builder, Void> LENIENT_PARSER = createParser(true);
|
|
|
+
|
|
|
+ private static ObjectParser<TextClassificationConfig.Builder, Void> createParser(boolean ignoreUnknownFields) {
|
|
|
+ ObjectParser<TextClassificationConfig.Builder, Void> parser =
|
|
|
+ new ObjectParser<>(NAME, ignoreUnknownFields, Builder::new);
|
|
|
|
|
|
- @SuppressWarnings({ "unchecked"})
|
|
|
- private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
|
|
|
- ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
|
|
|
- a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
|
|
|
parser.declareObject(
|
|
|
- ConstructingObjectParser.optionalConstructorArg(),
|
|
|
+ Builder::setVocabularyConfig,
|
|
|
(p, c) -> {
|
|
|
if (ignoreUnknownFields == false) {
|
|
|
throw ExceptionsHelper.badRequestException(
|
|
@@ -59,11 +56,12 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
VOCABULARY
|
|
|
);
|
|
|
parser.declareNamedObject(
|
|
|
- ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
|
|
+ Builder::setTokenization, (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
|
|
|
TOKENIZATION
|
|
|
);
|
|
|
- parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
|
|
|
- parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
|
|
+ parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
|
|
|
+ parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
|
|
|
+ parser.declareString(Builder::setResultsField, RESULTS_FIELD);
|
|
|
return parser;
|
|
|
}
|
|
|
|
|
@@ -71,16 +69,32 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
private final Tokenization tokenization;
|
|
|
private final List<String> classificationLabels;
|
|
|
private final int numTopClasses;
|
|
|
+ private final String resultsField;
|
|
|
|
|
|
public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig,
|
|
|
@Nullable Tokenization tokenization,
|
|
|
- @Nullable List<String> classificationLabels,
|
|
|
- @Nullable Integer numTopClasses) {
|
|
|
+ List<String> classificationLabels, // no longer @Nullable: null or fewer than 2 labels is rejected below
|
|
|
+ @Nullable Integer numTopClasses,
|
|
|
+ @Nullable String resultsField) {
|
|
|
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
|
|
|
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
|
|
|
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
|
|
|
- this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
|
|
|
+ if (classificationLabels == null || classificationLabels.size() < 2) { // replaces the old "null becomes empty list" default with a hard requirement
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] requires at least 2 [{}]; provided {}",
|
|
|
+ NAME, CLASSIFICATION_LABELS, classificationLabels);
|
|
|
+ }
|
|
|
+ this.classificationLabels = classificationLabels;
|
|
|
this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
|
|
|
+ if (this.numTopClasses == 0) { // a null argument was defaulted to -1 above, so only an explicit 0 is rejected here
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "[{}] requires at least 1 [{}]; provided [{}]",
|
|
|
+ NAME,
|
|
|
+ TextClassificationConfig.NUM_TOP_CLASSES,
|
|
|
+ numTopClasses
|
|
|
+ );
|
|
|
+ }
|
|
|
+ this.resultsField = resultsField; // stored as given; may be null
|
|
|
+
|
|
|
}
|
|
|
|
|
|
public TextClassificationConfig(StreamInput in) throws IOException {
|
|
@@ -88,6 +102,7 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
tokenization = in.readNamedWriteable(Tokenization.class);
|
|
|
classificationLabels = in.readStringList();
|
|
|
numTopClasses = in.readInt();
|
|
|
+ resultsField = in.readOptionalString();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -96,6 +111,7 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
out.writeNamedWriteable(tokenization);
|
|
|
out.writeStringCollection(classificationLabels);
|
|
|
out.writeInt(numTopClasses);
|
|
|
+ out.writeOptionalString(resultsField);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -103,10 +119,11 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
builder.startObject();
|
|
|
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
|
|
|
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
|
|
|
- if (classificationLabels.isEmpty() == false) {
|
|
|
- builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
|
|
- }
|
|
|
+ builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
|
|
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
|
|
+ if (resultsField != null) {
|
|
|
+ builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
|
|
|
+ }
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
}
|
|
@@ -140,12 +157,13 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
return Objects.equals(vocabularyConfig, that.vocabularyConfig)
|
|
|
&& Objects.equals(tokenization, that.tokenization)
|
|
|
&& Objects.equals(numTopClasses, that.numTopClasses)
|
|
|
- && Objects.equals(classificationLabels, that.classificationLabels);
|
|
|
+ && Objects.equals(classificationLabels, that.classificationLabels)
|
|
|
+ && Objects.equals(resultsField, that.resultsField);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses);
|
|
|
+ return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses, resultsField); // resultsField added so the hash covers the same fields equals() compares
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -166,8 +184,63 @@ public class TextClassificationConfig implements NlpConfig {
|
|
|
return numTopClasses;
|
|
|
}
|
|
|
|
|
|
+ public String getResultsField() {
|
|
|
+ return resultsField; // may be null: the constructor accepts a @Nullable value and stores it unchanged
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public boolean isAllocateOnly() {
|
|
|
return true;
|
|
|
}
|
|
|
+
|
|
|
+    /**
+     * Mutable builder for {@link TextClassificationConfig}. It is the value type populated by
+     * the strict and lenient {@code ObjectParser}s declared in the enclosing class.
+     *
+     * Every field starts out {@code null}, meaning "not provided". {@link #build()} hands the
+     * values to the {@link TextClassificationConfig} constructor unchanged; that constructor is
+     * what applies defaults and performs validation.
+     */
+    public static class Builder {
+        private VocabularyConfig vocabularyConfig;
+        private Tokenization tokenization;
+        private List<String> classificationLabels;
+        // Deliberately boxed. A primitive int would default to 0, which the config constructor
+        // rejects, so a config that simply omits num_top_classes could never be built. Keeping
+        // "unset" as null lets the constructor substitute its own default (-1) instead.
+        private Integer numTopClasses;
+        private String resultsField;
+
+        /** Creates an empty builder; used as the parser's value supplier. */
+        Builder() {}
+
+        /**
+         * Creates a builder pre-populated with the values of an existing config.
+         *
+         * @param config the config to copy from
+         */
+        Builder(TextClassificationConfig config) {
+            this.vocabularyConfig = config.vocabularyConfig;
+            this.tokenization = config.tokenization;
+            this.classificationLabels = config.classificationLabels;
+            this.numTopClasses = config.numTopClasses;
+            this.resultsField = config.resultsField;
+        }
+
+        /**
+         * @param vocabularyConfig vocabulary location, or {@code null} to leave it unset
+         * @return this builder
+         */
+        public Builder setVocabularyConfig(VocabularyConfig vocabularyConfig) {
+            this.vocabularyConfig = vocabularyConfig;
+            return this;
+        }
+
+        /**
+         * @param tokenization tokenization settings, or {@code null} to leave them unset
+         * @return this builder
+         */
+        public Builder setTokenization(Tokenization tokenization) {
+            this.tokenization = tokenization;
+            return this;
+        }
+
+        /**
+         * @param classificationLabels the class labels; the config constructor requires at least two
+         * @return this builder
+         */
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        /**
+         * @param numTopClasses number of top classes, or {@code null} to leave it unset; stored
+         *                      boxed, so {@code null} no longer triggers an unboxing
+         *                      NullPointerException
+         * @return this builder
+         */
+        public Builder setNumTopClasses(Integer numTopClasses) {
+            this.numTopClasses = numTopClasses;
+            return this;
+        }
+
+        /**
+         * @param resultsField name of the results field, or {@code null} to leave it unset
+         * @return this builder
+         */
+        public Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        /**
+         * Builds the config from the values collected so far.
+         *
+         * @return a new {@link TextClassificationConfig}
+         * @throws RuntimeException the bad-request exception raised by the config constructor when
+         *         fewer than two classification labels were supplied or num_top_classes is
+         *         explicitly 0 (concrete exception type comes from ExceptionsHelper)
+         */
+        public TextClassificationConfig build() {
+            return new TextClassificationConfig(vocabularyConfig,
+                tokenization,
+                classificationLabels,
+                numTopClasses,
+                resultsField);
+        }
+    }
|
|
|
}
|