|
@@ -8,20 +8,16 @@
|
|
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|
|
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
-import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.xcontent.ParseField;
|
|
|
-import org.elasticsearch.xcontent.ToXContent;
|
|
|
-import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.xcontent.XContentParser;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.Objects;
|
|
|
import java.util.Optional;
|
|
|
|
|
|
-public class RobertaTokenizationUpdate implements TokenizationUpdate {
|
|
|
+public class RobertaTokenizationUpdate extends AbstractTokenizationUpdate {
|
|
|
public static final ParseField NAME = new ParseField(RobertaTokenization.NAME);
|
|
|
|
|
|
public static ConstructingObjectParser<RobertaTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
|
|
@@ -30,25 +26,19 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
|
|
|
- PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
|
|
|
+ declareCommonParserFields(PARSER);
|
|
|
}
|
|
|
|
|
|
public static RobertaTokenizationUpdate fromXContent(XContentParser parser) {
|
|
|
return PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- private final Tokenization.Truncate truncate;
|
|
|
- private final Integer span;
|
|
|
-
|
|
|
public RobertaTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
|
|
|
- this.truncate = truncate;
|
|
|
- this.span = span;
|
|
|
+ super(truncate, span);
|
|
|
}
|
|
|
|
|
|
public RobertaTokenizationUpdate(StreamInput in) throws IOException {
|
|
|
- this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
|
|
|
- this.span = in.readOptionalInt();
|
|
|
+ super(in);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -58,12 +48,27 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
|
|
|
return robertaTokenization;
|
|
|
}
|
|
|
|
|
|
+ Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());
|
|
|
+
|
|
|
+ if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
|
|
|
+ // When truncate value is incompatible with span wipe out
|
|
|
+ // the existing span setting to avoid an invalid combination of settings.
|
|
|
+ // This avoids the user have to set span to the special unset value
|
|
|
+ return new RobertaTokenization(
|
|
|
+ robertaTokenization.withSpecialTokens(),
|
|
|
+ robertaTokenization.isAddPrefixSpace(),
|
|
|
+ robertaTokenization.maxSequenceLength(),
|
|
|
+ getTruncate(),
|
|
|
+ null
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
return new RobertaTokenization(
|
|
|
robertaTokenization.withSpecialTokens(),
|
|
|
robertaTokenization.isAddPrefixSpace(),
|
|
|
robertaTokenization.maxSequenceLength(),
|
|
|
- Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
|
|
|
- Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
|
|
|
+ Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
|
|
|
+ Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
|
|
|
);
|
|
|
}
|
|
|
throw ExceptionsHelper.badRequestException(
|
|
@@ -73,50 +78,13 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public boolean isNoop() {
|
|
|
- return truncate == null && span == null;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
|
|
- builder.startObject();
|
|
|
- if (truncate != null) {
|
|
|
- builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
|
|
|
- }
|
|
|
- if (span != null) {
|
|
|
- builder.field(Tokenization.SPAN.getPreferredName(), span);
|
|
|
- }
|
|
|
- builder.endObject();
|
|
|
- return builder;
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public String getWriteableName() {
|
|
|
return NAME.getPreferredName();
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public void writeTo(StreamOutput out) throws IOException {
|
|
|
- out.writeOptionalEnum(truncate);
|
|
|
- out.writeOptionalInt(span);
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public String getName() {
|
|
|
return NAME.getPreferredName();
|
|
|
}
|
|
|
-
|
|
|
- @Override
|
|
|
- public boolean equals(Object o) {
|
|
|
- if (this == o) return true;
|
|
|
- if (o == null || getClass() != o.getClass()) return false;
|
|
|
- RobertaTokenizationUpdate that = (RobertaTokenizationUpdate) o;
|
|
|
- return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public int hashCode() {
|
|
|
- return Objects.hash(truncate, span);
|
|
|
- }
|
|
|
}
|