Browse Source

[ML] Allow NLP truncate option to be updated when span is set (#91224)

David Kyle 2 years ago
parent
commit
defa765bc8

+ 5 - 0
docs/changelog/91224.yaml

@@ -0,0 +1,5 @@
+pr: 91224
+summary: Allow NLP truncate option to be updated when span is set
+area: Machine Learning
+type: bug
+issues: []

+ 92 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractTokenizationUpdate.java

@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public abstract class AbstractTokenizationUpdate implements TokenizationUpdate {
+
+    private final Tokenization.Truncate truncate;
+    private final Integer span;
+
+    protected static void declareCommonParserFields(ConstructingObjectParser<? extends AbstractTokenizationUpdate, Void> parser) {
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
+        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
+    }
+
+    public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
+        this.truncate = truncate;
+        this.span = span;
+    }
+
+    public AbstractTokenizationUpdate(StreamInput in) throws IOException {
+        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
+        if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
+            this.span = in.readOptionalInt();
+        } else {
+            this.span = null;
+        }
+    }
+
+    @Override
+    public boolean isNoop() {
+        return truncate == null && span == null;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (truncate != null) {
+            builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
+        }
+        if (span != null) {
+            builder.field(Tokenization.SPAN.getPreferredName(), span);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(truncate);
+        if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
+            out.writeOptionalInt(span);
+        }
+    }
+
+    public Integer getSpan() {
+        return span;
+    }
+
+    public Tokenization.Truncate getTruncate() {
+        return truncate;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o instanceof AbstractTokenizationUpdate == false) {
+            return false;
+        }
+        AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o;
+        return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(truncate, span);
+    }
+}

+ 21 - 59
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java

@@ -7,21 +7,17 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
-import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
-import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.Objects;
 import java.util.Optional;
 
-public class BertTokenizationUpdate implements TokenizationUpdate {
+public class BertTokenizationUpdate extends AbstractTokenizationUpdate {
 
     public static final ParseField NAME = BertTokenization.NAME;
 
@@ -31,29 +27,19 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
-        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
+        declareCommonParserFields(PARSER);
     }
 
     public static BertTokenizationUpdate fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    private final Tokenization.Truncate truncate;
-    private final Integer span;
-
     public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
-        this.truncate = truncate;
-        this.span = span;
+        super(truncate, span);
     }
 
     public BertTokenizationUpdate(StreamInput in) throws IOException {
-        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
-        if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
-            this.span = in.readOptionalInt();
-        } else {
-            this.span = null;
-        }
+        super(in);
     }
 
     @Override
@@ -66,65 +52,41 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
             );
         }
 
+        Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());
+
         if (isNoop()) {
             return originalConfig;
         }
 
+        if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
+            // When truncate value is incompatible with span wipe out
+            // the existing span setting to avoid an invalid combination of settings.
+            // This avoids the user have to set span to the special unset value
+            return new BertTokenization(
+                originalConfig.doLowerCase(),
+                originalConfig.withSpecialTokens(),
+                originalConfig.maxSequenceLength(),
+                getTruncate(),
+                null
+            );
+        }
+
         return new BertTokenization(
             originalConfig.doLowerCase(),
             originalConfig.withSpecialTokens(),
             originalConfig.maxSequenceLength(),
-            Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
-            Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
+            Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()),
+            Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan())
         );
     }
 
-    @Override
-    public boolean isNoop() {
-        return truncate == null && span == null;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        if (truncate != null) {
-            builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
-        }
-        if (span != null) {
-            builder.field(Tokenization.SPAN.getPreferredName(), span);
-        }
-        builder.endObject();
-        return builder;
-    }
-
     @Override
     public String getWriteableName() {
         return BertTokenization.NAME.getPreferredName();
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeOptionalEnum(truncate);
-        if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
-            out.writeOptionalInt(span);
-        }
-    }
-
     @Override
     public String getName() {
         return BertTokenization.NAME.getPreferredName();
     }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        BertTokenizationUpdate that = (BertTokenizationUpdate) o;
-        return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(truncate, span);
-    }
 }

+ 19 - 59
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java

@@ -7,21 +7,17 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
-import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
-import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.Objects;
 import java.util.Optional;
 
-public class MPNetTokenizationUpdate implements TokenizationUpdate {
+public class MPNetTokenizationUpdate extends AbstractTokenizationUpdate {
 
     public static final ParseField NAME = MPNetTokenization.NAME;
 
@@ -31,29 +27,19 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate {
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
-        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
+        declareCommonParserFields(PARSER);
     }
 
     public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    private final Tokenization.Truncate truncate;
-    private final Integer span;
-
     public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
-        this.truncate = truncate;
-        this.span = span;
+        super(truncate, span);
     }
 
     public MPNetTokenizationUpdate(StreamInput in) throws IOException {
-        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
-        if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
-            this.span = in.readOptionalInt();
-        } else {
-            this.span = null;
-        }
+        super(in);
     }
 
     @Override
@@ -70,61 +56,35 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate {
             return originalConfig;
         }
 
+        if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
+            // When truncate value is incompatible with span wipe out
+            // the existing span setting to avoid an invalid combination of settings.
+            // This avoids the user have to set span to the special unset value
+            return new MPNetTokenization(
+                originalConfig.doLowerCase(),
+                originalConfig.withSpecialTokens(),
+                originalConfig.maxSequenceLength(),
+                getTruncate(),
+                null
+            );
+        }
+
         return new MPNetTokenization(
             originalConfig.doLowerCase(),
             originalConfig.withSpecialTokens(),
             originalConfig.maxSequenceLength(),
-            Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
-            Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
+            Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
+            Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
         );
     }
 
-    @Override
-    public boolean isNoop() {
-        return truncate == null && span == null;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        if (truncate != null) {
-            builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
-        }
-        if (span != null) {
-            builder.field(Tokenization.SPAN.getPreferredName(), span);
-        }
-        builder.endObject();
-        return builder;
-    }
-
     @Override
     public String getWriteableName() {
         return MPNetTokenization.NAME.getPreferredName();
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeOptionalEnum(truncate);
-        if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
-            out.writeOptionalInt(span);
-        }
-    }
-
     @Override
     public String getName() {
         return MPNetTokenization.NAME.getPreferredName();
     }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
-        return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(truncate, span);
-    }
 }

+ 21 - 53
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdate.java

@@ -8,20 +8,16 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
-import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.Objects;
 import java.util.Optional;
 
-public class RobertaTokenizationUpdate implements TokenizationUpdate {
+public class RobertaTokenizationUpdate extends AbstractTokenizationUpdate {
     public static final ParseField NAME = new ParseField(RobertaTokenization.NAME);
 
     public static ConstructingObjectParser<RobertaTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
@@ -30,25 +26,19 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
-        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
+        declareCommonParserFields(PARSER);
     }
 
     public static RobertaTokenizationUpdate fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    private final Tokenization.Truncate truncate;
-    private final Integer span;
-
     public RobertaTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
-        this.truncate = truncate;
-        this.span = span;
+        super(truncate, span);
     }
 
     public RobertaTokenizationUpdate(StreamInput in) throws IOException {
-        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
-        this.span = in.readOptionalInt();
+        super(in);
     }
 
     @Override
@@ -58,12 +48,27 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
                 return robertaTokenization;
             }
 
+            Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());
+
+            if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
+                // When truncate value is incompatible with span wipe out
+                // the existing span setting to avoid an invalid combination of settings.
+                // This avoids the user have to set span to the special unset value
+                return new RobertaTokenization(
+                    robertaTokenization.withSpecialTokens(),
+                    robertaTokenization.isAddPrefixSpace(),
+                    robertaTokenization.maxSequenceLength(),
+                    getTruncate(),
+                    null
+                );
+            }
+
             return new RobertaTokenization(
                 robertaTokenization.withSpecialTokens(),
                 robertaTokenization.isAddPrefixSpace(),
                 robertaTokenization.maxSequenceLength(),
-                Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
-                Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
+                Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
+                Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
             );
         }
         throw ExceptionsHelper.badRequestException(
@@ -73,50 +78,13 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate {
         );
     }
 
-    @Override
-    public boolean isNoop() {
-        return truncate == null && span == null;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
-        builder.startObject();
-        if (truncate != null) {
-            builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
-        }
-        if (span != null) {
-            builder.field(Tokenization.SPAN.getPreferredName(), span);
-        }
-        builder.endObject();
-        return builder;
-    }
-
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeOptionalEnum(truncate);
-        out.writeOptionalInt(span);
-    }
-
     @Override
     public String getName() {
         return NAME.getPreferredName();
     }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        RobertaTokenizationUpdate that = (RobertaTokenizationUpdate) o;
-        return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(truncate, span);
-    }
 }

+ 29 - 18
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java

@@ -27,7 +27,16 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
     public enum Truncate {
         FIRST,
         SECOND,
-        NONE;
+        NONE {
+            @Override
+            public boolean isInCompatibleWithSpan() {
+                return false;
+            }
+        };
+
+        public boolean isInCompatibleWithSpan() {
+            return true;
+        }
 
         public static Truncate fromString(String value) {
             return valueOf(value.toUpperCase(Locale.ROOT));
@@ -50,7 +59,7 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
     private static final boolean DEFAULT_DO_LOWER_CASE = false;
     private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true;
     private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST;
-    private static final int DEFAULT_SPAN = -1;
+    private static final int UNSET_SPAN_VALUE = -1;
 
     static <T extends Tokenization> void declareCommonFields(ConstructingObjectParser<T, ?> parser) {
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE);
@@ -61,7 +70,7 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
     }
 
     public static BertTokenization createDefault() {
-        return new BertTokenization(null, null, null, Tokenization.DEFAULT_TRUNCATION, DEFAULT_SPAN);
+        return new BertTokenization(null, null, null, Tokenization.DEFAULT_TRUNCATION, UNSET_SPAN_VALUE);
     }
 
     protected final boolean doLowerCase;
@@ -84,10 +93,14 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
         this.withSpecialTokens = Optional.ofNullable(withSpecialTokens).orElse(DEFAULT_WITH_SPECIAL_TOKENS);
         this.maxSequenceLength = Optional.ofNullable(maxSequenceLength).orElse(DEFAULT_MAX_SEQUENCE_LENGTH);
         this.truncate = Optional.ofNullable(truncate).orElse(DEFAULT_TRUNCATION);
-        this.span = Optional.ofNullable(span).orElse(DEFAULT_SPAN);
-        if (this.span < 0 && this.span != -1) {
+        this.span = Optional.ofNullable(span).orElse(UNSET_SPAN_VALUE);
+        if (this.span < 0 && this.span != UNSET_SPAN_VALUE) {
             throw new IllegalArgumentException(
-                "[" + SPAN.getPreferredName() + "] must be non-negative to indicate span length or -1 to indicate no windowing should occur"
+                "["
+                    + SPAN.getPreferredName()
+                    + "] must be non-negative to indicate span length or ["
+                    + UNSET_SPAN_VALUE
+                    + "] to indicate no windowing should occur"
             );
         }
         if (this.span > this.maxSequenceLength) {
@@ -103,17 +116,7 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
                     + "]"
             );
         }
-        if (this.span != -1 && truncate != Truncate.NONE) {
-            throw new IllegalArgumentException(
-                "["
-                    + SPAN.getPreferredName()
-                    + "] must not be provided when ["
-                    + TRUNCATE.getPreferredName()
-                    + "] is not ["
-                    + Truncate.NONE
-                    + "]"
-            );
-        }
+        validateSpanAndTruncate(truncate, span);
     }
 
     public Tokenization(StreamInput in) throws IOException {
@@ -124,7 +127,7 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
         if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
             this.span = in.readInt();
         } else {
-            this.span = -1;
+            this.span = UNSET_SPAN_VALUE;
         }
     }
 
@@ -154,6 +157,14 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
         return builder;
     }
 
+    public static void validateSpanAndTruncate(@Nullable Truncate truncate, @Nullable Integer span) {
+        if ((span != null && span != UNSET_SPAN_VALUE) && (truncate != null && truncate.isInCompatibleWithSpan())) {
+            throw new IllegalArgumentException(
+                "[" + SPAN.getPreferredName() + "] must not be provided when [" + TRUNCATE.getPreferredName() + "] is [" + truncate + "]"
+            );
+        }
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;

+ 78 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdateTests.java

@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+
+import static org.hamcrest.Matchers.sameInstance;
+
+public class BertTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase<BertTokenizationUpdate> {
+
+    public static BertTokenizationUpdate randomInstance() {
+        Integer span = randomBoolean() ? null : randomIntBetween(8, 128);
+        Tokenization.Truncate truncate = randomBoolean() ? null : randomFrom(Tokenization.Truncate.values());
+
+        if (truncate != Tokenization.Truncate.NONE) {
+            span = null;
+        }
+        return new BertTokenizationUpdate(truncate, span);
+    }
+
+    public void testApply() {
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> new BertTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(BertTokenizationTests.createRandom())
+        );
+
+        var updatedSpan = new BertTokenizationUpdate(null, 100).apply(
+            new BertTokenization(false, false, 512, Tokenization.Truncate.NONE, 50)
+        );
+        assertEquals(new BertTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan);
+
+        var updatedTruncate = new BertTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply(
+            new BertTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new BertTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate);
+
+        var updatedNone = new BertTokenizationUpdate(Tokenization.Truncate.NONE, null).apply(
+            new BertTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new BertTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone);
+
+        var unmodified = new BertTokenization(true, true, 512, Tokenization.Truncate.NONE, null);
+        assertThat(new BertTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified));
+    }
+
+    public void testNoop() {
+        assertTrue(new BertTokenizationUpdate(null, null).isNoop());
+        assertFalse(new BertTokenizationUpdate(Tokenization.Truncate.SECOND, null).isNoop());
+        assertFalse(new BertTokenizationUpdate(null, 10).isNoop());
+        assertFalse(new BertTokenizationUpdate(Tokenization.Truncate.NONE, 10).isNoop());
+    }
+
+    @Override
+    protected Writeable.Reader<BertTokenizationUpdate> instanceReader() {
+        return BertTokenizationUpdate::new;
+    }
+
+    @Override
+    protected BertTokenizationUpdate createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected BertTokenizationUpdate mutateInstanceForVersion(BertTokenizationUpdate instance, Version version) {
+        if (version.before(Version.V_8_2_0)) {
+            return new BertTokenizationUpdate(instance.getTruncate(), null);
+        }
+
+        return instance;
+    }
+}

+ 71 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdateTests.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+
+import static org.hamcrest.Matchers.sameInstance;
+
+public class MPNetTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase<MPNetTokenizationUpdate> {
+
+    public static MPNetTokenizationUpdate randomInstance() {
+        Integer span = randomBoolean() ? null : randomIntBetween(8, 128);
+        Tokenization.Truncate truncate = randomBoolean() ? null : randomFrom(Tokenization.Truncate.values());
+
+        if (truncate != Tokenization.Truncate.NONE) {
+            span = null;
+        }
+        return new MPNetTokenizationUpdate(truncate, span);
+    }
+
+    public void testApply() {
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> new MPNetTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(MPNetTokenizationTests.createRandom())
+        );
+
+        var updatedSpan = new MPNetTokenizationUpdate(null, 100).apply(
+            new MPNetTokenization(false, false, 512, Tokenization.Truncate.NONE, 50)
+        );
+        assertEquals(new MPNetTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan);
+
+        var updatedTruncate = new MPNetTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply(
+            new MPNetTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new MPNetTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate);
+
+        var updatedNone = new MPNetTokenizationUpdate(Tokenization.Truncate.NONE, null).apply(
+            new MPNetTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new MPNetTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone);
+
+        var unmodified = new MPNetTokenization(true, true, 512, Tokenization.Truncate.NONE, null);
+        assertThat(new MPNetTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified));
+    }
+
+    @Override
+    protected Writeable.Reader<MPNetTokenizationUpdate> instanceReader() {
+        return MPNetTokenizationUpdate::new;
+    }
+
+    @Override
+    protected MPNetTokenizationUpdate createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected MPNetTokenizationUpdate mutateInstanceForVersion(MPNetTokenizationUpdate instance, Version version) {
+        if (version.before(Version.V_8_2_0)) {
+            return new MPNetTokenizationUpdate(instance.getTruncate(), null);
+        }
+
+        return instance;
+    }
+}

+ 71 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdateTests.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+
+import static org.hamcrest.Matchers.sameInstance;
+
+public class RobertaTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase<RobertaTokenizationUpdate> {
+
+    public static RobertaTokenizationUpdate randomInstance() {
+        Integer span = randomBoolean() ? null : randomIntBetween(8, 128);
+        Tokenization.Truncate truncate = randomBoolean() ? null : randomFrom(Tokenization.Truncate.values());
+
+        if (truncate != Tokenization.Truncate.NONE) {
+            span = null;
+        }
+        return new RobertaTokenizationUpdate(truncate, span);
+    }
+
+    public void testApply() {
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> new RobertaTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(RobertaTokenizationTests.createRandom())
+        );
+
+        var updatedSpan = new RobertaTokenizationUpdate(null, 100).apply(
+            new RobertaTokenization(false, false, 512, Tokenization.Truncate.NONE, 50)
+        );
+        assertEquals(new RobertaTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan);
+
+        var updatedTruncate = new RobertaTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply(
+            new RobertaTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new RobertaTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate);
+
+        var updatedNone = new RobertaTokenizationUpdate(Tokenization.Truncate.NONE, null).apply(
+            new RobertaTokenization(true, true, 512, Tokenization.Truncate.SECOND, null)
+        );
+        assertEquals(new RobertaTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone);
+
+        var unmodified = new RobertaTokenization(true, true, 512, Tokenization.Truncate.NONE, null);
+        assertThat(new RobertaTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified));
+    }
+
+    @Override
+    protected Writeable.Reader<RobertaTokenizationUpdate> instanceReader() {
+        return RobertaTokenizationUpdate::new;
+    }
+
+    @Override
+    protected RobertaTokenizationUpdate createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected RobertaTokenizationUpdate mutateInstanceForVersion(RobertaTokenizationUpdate instance, Version version) {
+        if (version.before(Version.V_8_2_0)) {
+            return new RobertaTokenizationUpdate(instance.getTruncate(), null);
+        }
+
+        return instance;
+    }
+}