|
@@ -9,15 +9,18 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|
|
|
|
|
import org.elasticsearch.TransportVersion;
|
|
|
import org.elasticsearch.Version;
|
|
|
+import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.xcontent.ObjectParser;
|
|
|
+import org.elasticsearch.xcontent.ParseField;
|
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.xcontent.XContentParser;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
|
|
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.Objects;
|
|
@@ -26,6 +29,7 @@ import java.util.Optional;
|
|
|
public class FillMaskConfig implements NlpConfig {
|
|
|
|
|
|
public static final String NAME = "fill_mask";
|
|
|
+ public static final String MASK_TOKEN = "mask_token";
|
|
|
public static final int DEFAULT_NUM_RESULTS = 5;
|
|
|
|
|
|
public static FillMaskConfig fromXContentStrict(XContentParser parser) {
|
|
@@ -36,6 +40,7 @@ public class FillMaskConfig implements NlpConfig {
|
|
|
return LENIENT_PARSER.apply(parser, null).build();
|
|
|
}
|
|
|
|
|
|
+ private static final ParseField MASK_TOKEN_FIELD = new ParseField(MASK_TOKEN);
|
|
|
private static final ObjectParser<FillMaskConfig.Builder, Void> STRICT_PARSER = createParser(false);
|
|
|
private static final ObjectParser<FillMaskConfig.Builder, Void> LENIENT_PARSER = createParser(true);
|
|
|
|
|
@@ -57,6 +62,7 @@ public class FillMaskConfig implements NlpConfig {
|
|
|
);
|
|
|
parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
|
|
|
parser.declareString(Builder::setResultsField, RESULTS_FIELD);
|
|
|
+ parser.declareString(Builder::setMaskToken, MASK_TOKEN_FIELD);
|
|
|
return parser;
|
|
|
}
|
|
|
|
|
@@ -101,6 +107,9 @@ public class FillMaskConfig implements NlpConfig {
|
|
|
if (resultsField != null) {
|
|
|
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
|
|
|
}
|
|
|
+ if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
|
|
|
+ builder.field(MASK_TOKEN_FIELD.getPreferredName(), tokenization.getMaskToken());
|
|
|
+ }
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
}
|
|
@@ -182,8 +191,9 @@ public class FillMaskConfig implements NlpConfig {
|
|
|
public static class Builder {
|
|
|
private VocabularyConfig vocabularyConfig;
|
|
|
private Tokenization tokenization;
|
|
|
- private int numTopClasses;
|
|
|
+ private Integer numTopClasses;
|
|
|
private String resultsField;
|
|
|
+ private String maskToken;
|
|
|
|
|
|
Builder() {}
|
|
|
|
|
@@ -214,8 +224,27 @@ public class FillMaskConfig implements NlpConfig {
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
- public FillMaskConfig build() {
|
|
|
+ public FillMaskConfig.Builder setMaskToken(String maskToken) {
|
|
|
+ this.maskToken = maskToken;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public FillMaskConfig build() throws IllegalArgumentException {
|
|
|
+ if (tokenization == null) {
|
|
|
+ tokenization = Tokenization.createDefault();
|
|
|
+ }
|
|
|
+ validateMaskToken(tokenization.getMaskToken());
|
|
|
return new FillMaskConfig(vocabularyConfig, tokenization, numTopClasses, resultsField);
|
|
|
}
|
|
|
+
|
|
|
+ private void validateMaskToken(String tokenizationMaskToken) throws IllegalArgumentException {
|
|
|
+ if (maskToken != null) {
|
|
|
+ if (maskToken.equals(tokenizationMaskToken) == false) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ Strings.format("Mask token requested was [%s] but must be [%s] for this model", maskToken, tokenizationMaskToken)
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|