瀏覽代碼

[ML] Add ability to update the truncation option at inference (#80267)

The truncate option is set in an inference_config object
defined in the inference processor or the _infer API
David Kyle 3 年之前
父節點
當前提交
19f0df7ae0
共有 29 個文件被更改,包括 928 次插入和 137 次刪除
  1. 19 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  2. 111 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java
  3. 35 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdate.java
  4. 26 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdate.java
  5. 83 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java
  6. 31 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdate.java
  7. 52 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdate.java
  8. 29 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java
  9. 21 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TokenizationUpdate.java
  10. 42 18
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java
  11. 47 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java
  12. 42 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java
  13. 66 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java
  14. 44 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java
  15. 58 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java
  16. 44 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java
  17. 58 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java
  18. 50 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  19. 3 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  20. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  21. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  22. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  23. 2 13
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  24. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  25. 12 7
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java
  26. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  27. 2 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  28. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java
  29. 39 22
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

+ 19 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResu
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenizationUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
@@ -55,6 +56,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassification
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
@@ -399,7 +401,7 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(
             new NamedXContentRegistry.Entry(
                 InferenceConfigUpdate.class,
-                new ParseField(TextClassificationConfig.NAME),
+                new ParseField(TextClassificationConfigUpdate.NAME),
                 TextClassificationConfigUpdate::fromXContentStrict
             )
         );
@@ -434,6 +436,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             )
         );
 
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                TokenizationUpdate.class,
+                BertTokenizationUpdate.NAME,
+                (p, c) -> BertTokenizationUpdate.fromXContent(p)
+            )
+        );
+
         return namedXContent;
     }
 
@@ -582,6 +592,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new NamedWriteableRegistry.Entry(Tokenization.class, BertTokenization.NAME.getPreferredName(), BertTokenization::new)
         );
 
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TokenizationUpdate.class,
+                BertTokenizationUpdate.NAME.getPreferredName(),
+                BertTokenizationUpdate::new
+            )
+        );
+
         return namedWriteables;
     }
 }

+ 111 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class BertTokenizationUpdate implements TokenizationUpdate {
+
+    public static final ParseField NAME = BertTokenization.NAME;
+
+    public static ConstructingObjectParser<BertTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
+        "bert_tokenization_update",
+        a -> new BertTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]))
+    );
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
+    }
+
+    public static BertTokenizationUpdate fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Tokenization.Truncate truncate;
+
+    public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate) {
+        this.truncate = truncate;
+    }
+
+    public BertTokenizationUpdate(StreamInput in) throws IOException {
+        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
+    }
+
+    @Override
+    public Tokenization apply(Tokenization originalConfig) {
+        if (isNoop()) {
+            return originalConfig;
+        }
+
+        if (originalConfig instanceof BertTokenization == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Tokenization config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName()
+            );
+        }
+
+        return new BertTokenization(
+            originalConfig.doLowerCase(),
+            originalConfig.withSpecialTokens(),
+            originalConfig.maxSequenceLength(),
+            this.truncate
+        );
+    }
+
+    @Override
+    public boolean isNoop() {
+        return truncate == null;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return BertTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(truncate);
+    }
+
+    @Override
+    public String getName() {
+        return BertTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        BertTokenizationUpdate that = (BertTokenizationUpdate) o;
+        return truncate == that.truncate;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(truncate);
+    }
+}

+ 35 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdate.java

@@ -22,6 +22,7 @@ import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.NUM_TOP_CLASSES;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 
 public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
 
@@ -31,11 +32,12 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
         Map<String, Object> options = new HashMap<>(map);
         Integer numTopClasses = (Integer) options.remove(NUM_TOP_CLASSES.getPreferredName());
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
 
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new FillMaskConfigUpdate(numTopClasses, resultsField);
+        return new FillMaskConfigUpdate(numTopClasses, resultsField, tokenizationUpdate);
     }
 
     private static final ObjectParser<FillMaskConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@@ -44,6 +46,11 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
         ObjectParser<FillMaskConfigUpdate.Builder, Void> parser = new ObjectParser<>(NAME, lenient, FillMaskConfigUpdate.Builder::new);
         parser.declareString(FillMaskConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
         parser.declareInt(FillMaskConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        parser.declareNamedObject(
+            FillMaskConfigUpdate.Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
         return parser;
     }
 
@@ -54,32 +61,33 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
     private final Integer numTopClasses;
     private final String resultsField;
 
-    public FillMaskConfigUpdate(Integer numTopClasses, String resultsField) {
+    public FillMaskConfigUpdate(Integer numTopClasses, String resultsField, TokenizationUpdate tokenizationUpdate) {
+        super(tokenizationUpdate);
         this.numTopClasses = numTopClasses;
         this.resultsField = resultsField;
     }
 
     public FillMaskConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         this.numTopClasses = in.readOptionalInt();
         this.resultsField = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalInt(numTopClasses);
         out.writeOptionalString(resultsField);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (numTopClasses != null) {
             builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         }
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -115,12 +123,17 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
         if (resultsField != null) {
             builder.setResultsField(resultsField);
         }
+        if (tokenizationUpdate != null) {
+            builder.setTokenization(tokenizationUpdate.apply(fillMaskConfig.getTokenization()));
+
+        }
         return builder.build();
     }
 
     boolean isNoop(FillMaskConfig originalConfig) {
         return (this.numTopClasses == null || this.numTopClasses == originalConfig.getNumTopClasses())
-            && (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()));
+            && (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()))
+            && super.isNoop();
     }
 
     @Override
@@ -133,9 +146,13 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
         return resultsField;
     }
 
+    public Integer getNumTopClasses() {
+        return numTopClasses;
+    }
+
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new Builder().setNumTopClasses(numTopClasses).setResultsField(resultsField);
+        return new Builder().setNumTopClasses(numTopClasses).setResultsField(resultsField).setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
@@ -143,17 +160,20 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         FillMaskConfigUpdate that = (FillMaskConfigUpdate) o;
-        return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(numTopClasses, that.numTopClasses)
+            && Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(numTopClasses, resultsField);
+        return Objects.hash(numTopClasses, resultsField, tokenizationUpdate);
     }
 
     public static class Builder implements InferenceConfigUpdate.Builder<FillMaskConfigUpdate.Builder, FillMaskConfigUpdate> {
         private Integer numTopClasses;
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         public FillMaskConfigUpdate.Builder setNumTopClasses(Integer numTopClasses) {
             this.numTopClasses = numTopClasses;
@@ -166,8 +186,13 @@ public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXConte
             return this;
         }
 
+        public FillMaskConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public FillMaskConfigUpdate build() {
-            return new FillMaskConfigUpdate(this.numTopClasses, this.resultsField);
+            return new FillMaskConfigUpdate(this.numTopClasses, this.resultsField, this.tokenizationUpdate);
         }
     }
 }

+ 26 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdate.java

@@ -13,7 +13,6 @@ import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
 import java.io.IOException;
 import java.util.HashMap;
@@ -22,18 +21,20 @@ import java.util.Objects;
 import java.util.Optional;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 
-public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+public class NerConfigUpdate extends NlpConfigUpdate {
     public static final String NAME = NerConfig.NAME;
 
     public static NerConfigUpdate fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
 
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new NerConfigUpdate(resultsField);
+        return new NerConfigUpdate(resultsField, tokenizationUpdate);
     }
 
     private static final ObjectParser<NerConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@@ -41,6 +42,11 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
     private static ObjectParser<NerConfigUpdate.Builder, Void> createParser(boolean lenient) {
         ObjectParser<NerConfigUpdate.Builder, Void> parser = new ObjectParser<>(NAME, lenient, NerConfigUpdate.Builder::new);
         parser.declareString(NerConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        parser.declareNamedObject(
+            NerConfigUpdate.Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
         return parser;
     }
 
@@ -50,26 +56,27 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
 
     private final String resultsField;
 
-    public NerConfigUpdate(String resultsField) {
+    public NerConfigUpdate(String resultsField, TokenizationUpdate tokenizationUpdate) {
+        super(tokenizationUpdate);
         this.resultsField = resultsField;
     }
 
     public NerConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         this.resultsField = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalString(resultsField);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -99,14 +106,14 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
 
         return new NerConfig(
             nerConfig.getVocabularyConfig(),
-            nerConfig.getTokenization(),
+            (tokenizationUpdate == null) ? nerConfig.getTokenization() : tokenizationUpdate.apply(nerConfig.getTokenization()),
             nerConfig.getClassificationLabels(),
             Optional.ofNullable(resultsField).orElse(nerConfig.getResultsField())
         );
     }
 
     boolean isNoop(NerConfig originalConfig) {
-        return (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()));
+        return (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField())) && super.isNoop();
     }
 
     @Override
@@ -121,7 +128,7 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
 
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new NerConfigUpdate.Builder().setResultsField(resultsField);
+        return new NerConfigUpdate.Builder().setResultsField(resultsField).setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
@@ -129,16 +136,17 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         NerConfigUpdate that = (NerConfigUpdate) o;
-        return Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(resultsField, that.resultsField) && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(resultsField);
+        return Objects.hash(resultsField, tokenizationUpdate);
     }
 
     public static class Builder implements InferenceConfigUpdate.Builder<NerConfigUpdate.Builder, NerConfigUpdate> {
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         @Override
         public NerConfigUpdate.Builder setResultsField(String resultsField) {
@@ -146,8 +154,13 @@ public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObj
             return this;
         }
 
+        public NerConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public NerConfigUpdate build() {
-            return new NerConfigUpdate(this.resultsField);
+            return new NerConfigUpdate(resultsField, tokenizationUpdate);
         }
     }
 }

+ 83 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java

@@ -7,6 +7,88 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
-public abstract class NlpConfigUpdate implements InferenceConfigUpdate {
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
+import java.io.IOException;
+import java.util.Map;
+
+public abstract class NlpConfigUpdate implements InferenceConfigUpdate, NamedXContentObject {
+
+    @SuppressWarnings("unchecked")
+    public static TokenizationUpdate tokenizationFromMap(Map<String, Object> map) {
+        Map<String, Object> tokenziation = (Map<String, Object>) map.remove("tokenization");
+        if (tokenziation == null) {
+            return null;
+        }
+
+        Map<String, Object> bert = (Map<String, Object>) tokenziation.remove("bert");
+        if (bert == null && tokenziation.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("unknown tokenization type expecting one of [bert] got {}", tokenziation.keySet());
+        }
+        Object truncate = bert.remove("truncate");
+        if (truncate == null) {
+            return null;
+        }
+        return new BertTokenizationUpdate(Tokenization.Truncate.fromString(truncate.toString()));
+    }
+
+    protected final TokenizationUpdate tokenizationUpdate;
+
+    public NlpConfigUpdate(@Nullable TokenizationUpdate tokenizationUpdate) {
+        this.tokenizationUpdate = tokenizationUpdate;
+    }
+
+    public NlpConfigUpdate(StreamInput in) throws IOException {
+        if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
+            tokenizationUpdate = in.readOptionalNamedWriteable(TokenizationUpdate.class);
+        } else {
+            tokenizationUpdate = null;
+        }
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
+            out.writeOptionalNamedWriteable(tokenizationUpdate);
+        }
+    }
+
+    protected boolean isNoop() {
+        return tokenizationUpdate == null || tokenizationUpdate.isNoop();
+    }
+
+    public TokenizationUpdate getTokenizationUpdate() {
+        return tokenizationUpdate;
+    }
+
+    @Override
+    public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (tokenizationUpdate != null) {
+            NamedXContentObjectHelper.writeNamedObject(builder, params, NlpConfig.TOKENIZATION.getPreferredName(), tokenizationUpdate);
+        }
+        doXContentBody(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    public abstract XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException;
+
+    /**
+     * Required because this class implements 2 interfaces defining the
+     * method {@code String getName()} and the compiler insists it must
+     * be resolved here in the abstract class
+     */
+    @Override
+    public String getName() {
+        return InferenceConfigUpdate.super.getName();
+    }
 }

+ 31 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdate.java

@@ -21,6 +21,7 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 
 public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
     public static final String NAME = PassThroughConfig.NAME;
@@ -28,11 +29,12 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
     public static PassThroughConfigUpdate fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
 
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new PassThroughConfigUpdate(resultsField);
+        return new PassThroughConfigUpdate(resultsField, tokenizationUpdate);
     }
 
     private static final ObjectParser<PassThroughConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@@ -44,6 +46,11 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
             PassThroughConfigUpdate.Builder::new
         );
         parser.declareString(PassThroughConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        parser.declareNamedObject(
+            PassThroughConfigUpdate.Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
         return parser;
     }
 
@@ -53,26 +60,27 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
 
     private final String resultsField;
 
-    public PassThroughConfigUpdate(String resultsField) {
+    public PassThroughConfigUpdate(String resultsField, TokenizationUpdate tokenizationUpdate) {
+        super(tokenizationUpdate);
         this.resultsField = resultsField;
     }
 
     public PassThroughConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         this.resultsField = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalString(resultsField);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -88,7 +96,7 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
 
     @Override
     public InferenceConfig apply(InferenceConfig originalConfig) {
-        if (resultsField == null || resultsField.equals(originalConfig.getResultsField())) {
+        if ((resultsField == null || resultsField.equals(originalConfig.getResultsField())) && super.isNoop()) {
             return originalConfig;
         }
 
@@ -101,7 +109,13 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
         }
 
         PassThroughConfig passThroughConfig = (PassThroughConfig) originalConfig;
-        return new PassThroughConfig(passThroughConfig.getVocabularyConfig(), passThroughConfig.getTokenization(), resultsField);
+        return new PassThroughConfig(
+            passThroughConfig.getVocabularyConfig(),
+            (tokenizationUpdate == null)
+                ? passThroughConfig.getTokenization()
+                : tokenizationUpdate.apply(passThroughConfig.getTokenization()),
+            resultsField == null ? originalConfig.getResultsField() : resultsField
+        );
     }
 
     @Override
@@ -116,7 +130,7 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
 
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new PassThroughConfigUpdate.Builder().setResultsField(resultsField);
+        return new PassThroughConfigUpdate.Builder().setResultsField(resultsField).setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
@@ -124,16 +138,17 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         PassThroughConfigUpdate that = (PassThroughConfigUpdate) o;
-        return Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(resultsField, that.resultsField) && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(resultsField);
+        return Objects.hash(resultsField, tokenizationUpdate);
     }
 
     public static class Builder implements InferenceConfigUpdate.Builder<PassThroughConfigUpdate.Builder, PassThroughConfigUpdate> {
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         @Override
         public PassThroughConfigUpdate.Builder setResultsField(String resultsField) {
@@ -141,8 +156,13 @@ public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXCo
             return this;
         }
 
+        public PassThroughConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public PassThroughConfigUpdate build() {
-            return new PassThroughConfigUpdate(this.resultsField);
+            return new PassThroughConfigUpdate(this.resultsField, tokenizationUpdate);
         }
     }
 }

+ 52 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdate.java

@@ -21,6 +21,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.CLASSIFICATION_LABELS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.NUM_TOP_CLASSES;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.RESULTS_FIELD;
@@ -35,11 +36,12 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         Integer numTopClasses = (Integer) options.remove(NUM_TOP_CLASSES.getPreferredName());
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
         List<String> classificationLabels = (List<String>) options.remove(CLASSIFICATION_LABELS.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
 
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new TextClassificationConfigUpdate(classificationLabels, numTopClasses, resultsField);
+        return new TextClassificationConfigUpdate(classificationLabels, numTopClasses, resultsField, tokenizationUpdate);
     }
 
     private static final ObjectParser<TextClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@@ -49,6 +51,11 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
         parser.declareString(Builder::setResultsField, RESULTS_FIELD);
         parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        parser.declareNamedObject(
+            Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
         return parser;
     }
 
@@ -60,13 +67,20 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
     private final Integer numTopClasses;
     private final String resultsField;
 
-    public TextClassificationConfigUpdate(List<String> classificationLabels, Integer numTopClasses, String resultsField) {
+    public TextClassificationConfigUpdate(
+        List<String> classificationLabels,
+        Integer numTopClasses,
+        String resultsField,
+        TokenizationUpdate tokenizationUpdate
+    ) {
+        super(tokenizationUpdate);
         this.classificationLabels = classificationLabels;
         this.numTopClasses = numTopClasses;
         this.resultsField = resultsField;
     }
 
     public TextClassificationConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         classificationLabels = in.readOptionalStringList();
         numTopClasses = in.readOptionalVInt();
         resultsField = in.readOptionalString();
@@ -84,6 +98,7 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalStringCollection(classificationLabels);
         out.writeOptionalVInt(numTopClasses);
         out.writeOptionalString(resultsField);
@@ -122,13 +137,19 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         if (resultsField != null) {
             builder.setResultsField(resultsField);
         }
+
+        if (tokenizationUpdate != null) {
+            builder.setTokenization(tokenizationUpdate.apply(classificationConfig.getTokenization()));
+        }
+
         return builder.build();
     }
 
     boolean isNoop(TextClassificationConfig originalConfig) {
         return (this.numTopClasses == null || this.numTopClasses == originalConfig.getNumTopClasses())
             && (this.classificationLabels == null)
-            && (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()));
+            && (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()))
+            && super.isNoop();
     }
 
     @Override
@@ -141,14 +162,24 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         return resultsField;
     }
 
+    public Integer getNumTopClasses() {
+        return numTopClasses;
+    }
+
+    public List<String> getClassificationLabels() {
+        return classificationLabels;
+    }
+
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new Builder().setClassificationLabels(classificationLabels).setNumTopClasses(numTopClasses).setResultsField(resultsField);
+        return new Builder().setClassificationLabels(classificationLabels)
+            .setNumTopClasses(numTopClasses)
+            .setResultsField(resultsField)
+            .setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (numTopClasses != null) {
             builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
         }
@@ -158,7 +189,6 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -169,12 +199,13 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         TextClassificationConfigUpdate that = (TextClassificationConfigUpdate) o;
         return Objects.equals(classificationLabels, that.classificationLabels)
             && Objects.equals(numTopClasses, that.numTopClasses)
-            && Objects.equals(resultsField, that.resultsField);
+            && Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(classificationLabels, numTopClasses, resultsField);
+        return Objects.hash(classificationLabels, numTopClasses, resultsField, tokenizationUpdate);
     }
 
     public static class Builder
@@ -183,6 +214,7 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
         private List<String> classificationLabels;
         private Integer numTopClasses;
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         public TextClassificationConfigUpdate.Builder setNumTopClasses(Integer numTopClasses) {
             this.numTopClasses = numTopClasses;
@@ -200,8 +232,18 @@ public class TextClassificationConfigUpdate extends NlpConfigUpdate implements N
             return this;
         }
 
+        public TextClassificationConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public TextClassificationConfigUpdate build() {
-            return new TextClassificationConfigUpdate(this.classificationLabels, this.numTopClasses, this.resultsField);
+            return new TextClassificationConfigUpdate(
+                this.classificationLabels,
+                this.numTopClasses,
+                this.resultsField,
+                this.tokenizationUpdate
+            );
         }
     }
 }

+ 29 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java

@@ -21,6 +21,7 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 
 public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
 
@@ -29,11 +30,12 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
     public static TextEmbeddingConfigUpdate fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
 
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new TextEmbeddingConfigUpdate(resultsField);
+        return new TextEmbeddingConfigUpdate(resultsField, tokenizationUpdate);
     }
 
     private static final ObjectParser<TextEmbeddingConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
@@ -45,6 +47,11 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
             TextEmbeddingConfigUpdate.Builder::new
         );
         parser.declareString(TextEmbeddingConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        parser.declareNamedObject(
+            TextEmbeddingConfigUpdate.Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, lenient),
+            TOKENIZATION
+        );
         return parser;
     }
 
@@ -54,26 +61,27 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
 
     private final String resultsField;
 
-    public TextEmbeddingConfigUpdate(String resultsField) {
+    public TextEmbeddingConfigUpdate(String resultsField, TokenizationUpdate tokenizationUpdate) {
+        super(tokenizationUpdate);
         this.resultsField = resultsField;
     }
 
     public TextEmbeddingConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         this.resultsField = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalString(resultsField);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -89,7 +97,7 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
 
     @Override
     public InferenceConfig apply(InferenceConfig originalConfig) {
-        if (resultsField == null || resultsField.equals(originalConfig.getResultsField())) {
+        if ((resultsField == null || resultsField.equals(originalConfig.getResultsField())) && super.isNoop()) {
             return originalConfig;
         }
 
@@ -102,7 +110,11 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
         }
 
         TextEmbeddingConfig embeddingConfig = (TextEmbeddingConfig) originalConfig;
-        return new TextEmbeddingConfig(embeddingConfig.getVocabularyConfig(), embeddingConfig.getTokenization(), resultsField);
+        return new TextEmbeddingConfig(
+            embeddingConfig.getVocabularyConfig(),
+            tokenizationUpdate == null ? embeddingConfig.getTokenization() : tokenizationUpdate.apply(embeddingConfig.getTokenization()),
+            resultsField == null ? embeddingConfig.getResultsField() : resultsField
+        );
     }
 
     @Override
@@ -117,7 +129,7 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
 
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new Builder().setResultsField(resultsField);
+        return new Builder().setResultsField(resultsField).setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
@@ -125,16 +137,17 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         TextEmbeddingConfigUpdate that = (TextEmbeddingConfigUpdate) o;
-        return Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(resultsField, that.resultsField) && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(resultsField);
+        return Objects.hash(resultsField, tokenizationUpdate);
     }
 
     public static class Builder implements InferenceConfigUpdate.Builder<TextEmbeddingConfigUpdate.Builder, TextEmbeddingConfigUpdate> {
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         @Override
         public Builder setResultsField(String resultsField) {
@@ -142,8 +155,13 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
             return this;
         }
 
+        public TextEmbeddingConfigUpdate.Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public TextEmbeddingConfigUpdate build() {
-            return new TextEmbeddingConfigUpdate(this.resultsField);
+            return new TextEmbeddingConfigUpdate(resultsField, tokenizationUpdate);
         }
     }
 }

+ 21 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TokenizationUpdate.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+public interface TokenizationUpdate extends NamedXContentObject, NamedWriteable {
+
+    /**
+     * @return True if applying this update would not modify the original tokenization
+     */
+    boolean isNoop();
+
+    Tokenization apply(Tokenization originalConfig);
+}

+ 42 - 18
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java

@@ -10,7 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -24,6 +24,7 @@ import java.util.Objects;
 import java.util.Optional;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.LABELS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.MULTI_LABEL;
 
@@ -32,7 +33,7 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
     public static final String NAME = "zero_shot_classification";
 
     public static ZeroShotClassificationConfigUpdate fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+        return STRICT_PARSER.apply(parser, null).build();
     }
 
     @SuppressWarnings({ "unchecked" })
@@ -41,22 +42,28 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         Boolean isMultiLabel = (Boolean) options.remove(MULTI_LABEL.getPreferredName());
         List<String> labels = (List<String>) options.remove(LABELS.getPreferredName());
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());
+        TokenizationUpdate tokenizationUpdate = NlpConfigUpdate.tokenizationFromMap(options);
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
         }
-        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField);
+        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField, tokenizationUpdate);
     }
 
     @SuppressWarnings({ "unchecked" })
-    private static final ConstructingObjectParser<ZeroShotClassificationConfigUpdate, Void> STRICT_PARSER = new ConstructingObjectParser<>(
+    private static final ObjectParser<ZeroShotClassificationConfigUpdate.Builder, Void> STRICT_PARSER = new ObjectParser<>(
         NAME,
-        a -> new ZeroShotClassificationConfigUpdate((List<String>) a[0], (Boolean) a[1], (String) a[2])
+        ZeroShotClassificationConfigUpdate.Builder::new
     );
 
     static {
-        STRICT_PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
-        STRICT_PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
-        STRICT_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
+        STRICT_PARSER.declareStringArray(Builder::setLabels, LABELS);
+        STRICT_PARSER.declareBoolean(Builder::setMultiLabel, MULTI_LABEL);
+        STRICT_PARSER.declareString(Builder::setResultsField, RESULTS_FIELD);
+        STRICT_PARSER.declareNamedObject(
+            Builder::setTokenizationUpdate,
+            (p, c, n) -> p.namedObject(TokenizationUpdate.class, n, false),
+            TOKENIZATION
+        );
     }
 
     private final List<String> labels;
@@ -66,8 +73,10 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
     public ZeroShotClassificationConfigUpdate(
         @Nullable List<String> labels,
         @Nullable Boolean isMultiLabel,
-        @Nullable String resultsField
+        @Nullable String resultsField,
+        @Nullable TokenizationUpdate tokenizationUpdate
     ) {
+        super(tokenizationUpdate);
         this.labels = labels;
         if (labels != null && labels.isEmpty()) {
             throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
@@ -77,6 +86,7 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
     }
 
     public ZeroShotClassificationConfigUpdate(StreamInput in) throws IOException {
+        super(in);
         labels = in.readOptionalStringList();
         isMultiLabel = in.readOptionalBoolean();
         resultsField = in.readOptionalString();
@@ -84,14 +94,14 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
         out.writeOptionalStringCollection(labels);
         out.writeOptionalBoolean(isMultiLabel);
         out.writeOptionalString(resultsField);
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         if (labels != null) {
             builder.field(LABELS.getPreferredName(), labels);
         }
@@ -101,7 +111,6 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
-        builder.endObject();
         return builder;
     }
 
@@ -134,7 +143,7 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         return new ZeroShotClassificationConfig(
             zeroShotConfig.getClassificationLabels(),
             zeroShotConfig.getVocabularyConfig(),
-            zeroShotConfig.getTokenization(),
+            tokenizationUpdate == null ? zeroShotConfig.getTokenization() : tokenizationUpdate.apply(zeroShotConfig.getTokenization()),
             zeroShotConfig.getHypothesisTemplate(),
             Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
             Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()),
@@ -145,7 +154,8 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
     boolean isNoop(ZeroShotClassificationConfig originalConfig) {
         return (labels == null || labels.equals(originalConfig.getClassificationLabels()))
             && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()))
-            && (resultsField == null || resultsField.equals(originalConfig.getResultsField()));
+            && (resultsField == null || resultsField.equals(originalConfig.getResultsField()))
+            && super.isNoop();
     }
 
     @Override
@@ -160,7 +170,10 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
 
     @Override
     public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
-        return new Builder().setLabels(labels).setMultiLabel(isMultiLabel);
+        return new Builder().setLabels(labels)
+            .setMultiLabel(isMultiLabel)
+            .setResultsField(resultsField)
+            .setTokenizationUpdate(tokenizationUpdate);
     }
 
     @Override
@@ -176,24 +189,30 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         ZeroShotClassificationConfigUpdate that = (ZeroShotClassificationConfigUpdate) o;
         return Objects.equals(isMultiLabel, that.isMultiLabel)
             && Objects.equals(labels, that.labels)
-            && Objects.equals(resultsField, that.resultsField);
+            && Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(tokenizationUpdate, that.tokenizationUpdate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(labels, isMultiLabel, resultsField);
+        return Objects.hash(labels, isMultiLabel, resultsField, tokenizationUpdate);
     }
 
     public List<String> getLabels() {
         return labels;
     }
 
+    public Boolean getMultiLabel() {
+        return isMultiLabel;
+    }
+
     public static class Builder
         implements
             InferenceConfigUpdate.Builder<ZeroShotClassificationConfigUpdate.Builder, ZeroShotClassificationConfigUpdate> {
         private List<String> labels;
         private Boolean isMultiLabel;
         private String resultsField;
+        private TokenizationUpdate tokenizationUpdate;
 
         @Override
         public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
@@ -211,8 +230,13 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
             return this;
         }
 
+        public Builder setTokenizationUpdate(TokenizationUpdate tokenizationUpdate) {
+            this.tokenizationUpdate = tokenizationUpdate;
+            return this;
+        }
+
         public ZeroShotClassificationConfigUpdate build() {
-            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField);
+            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField, tokenizationUpdate);
         }
     }
 }

+ 47 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -23,14 +26,20 @@ import static org.hamcrest.Matchers.equalTo;
 public class FillMaskConfigUpdateTests extends AbstractBWCSerializationTestCase<FillMaskConfigUpdate> {
 
     public void testFromMap() {
-        FillMaskConfigUpdate expected = new FillMaskConfigUpdate(3, "ml-results");
+        FillMaskConfigUpdate expected = new FillMaskConfigUpdate(3, "ml-results", new BertTokenizationUpdate(Tokenization.Truncate.FIRST));
         Map<String, Object> config = new HashMap<>() {
             {
                 put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
                 put(NlpConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
-        assertThat(FillMaskConfigUpdate.fromMap(config), equalTo(expected));
+        var pp = FillMaskConfigUpdate.fromMap(config);
+        assertThat(pp, equalTo(expected));
     }
 
     public void testFromMapWithUnknownField() {
@@ -50,6 +59,12 @@ public class FillMaskConfigUpdateTests extends AbstractBWCSerializationTestCase<
                 .isNoop(new FillMaskConfig.Builder().setResultsField("bar").build())
         );
 
+        assertFalse(
+            new FillMaskConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(Tokenization.Truncate.SECOND))
+                .build()
+                .isNoop(new FillMaskConfig.Builder().setResultsField("bar").build())
+        );
+
         assertTrue(
             new FillMaskConfigUpdate.Builder().setNumTopClasses(3).build().isNoop(new FillMaskConfig.Builder().setNumTopClasses(3).build())
         );
@@ -70,6 +85,20 @@ public class FillMaskConfigUpdateTests extends AbstractBWCSerializationTestCase<
                 new FillMaskConfigUpdate.Builder().setNumTopClasses(originalConfig.getNumTopClasses() + 1).build().apply(originalConfig)
             )
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new FillMaskConfig.Builder(originalConfig).setTokenization(tokenization).build(),
+            equalTo(
+                new FillMaskConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate)).build().apply(originalConfig)
+            )
+        );
     }
 
     @Override
@@ -91,11 +120,27 @@ public class FillMaskConfigUpdateTests extends AbstractBWCSerializationTestCase<
         if (randomBoolean()) {
             builder.setResultsField(randomAlphaOfLength(8));
         }
+        if (randomBoolean()) {
+            builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values())));
+        }
         return builder.build();
     }
 
     @Override
     protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null);
+        }
         return instance;
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 42 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -24,10 +27,15 @@ import static org.hamcrest.Matchers.sameInstance;
 public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerConfigUpdate> {
 
     public void testFromMap() {
-        NerConfigUpdate expected = new NerConfigUpdate("ml-results");
+        NerConfigUpdate expected = new NerConfigUpdate("ml-results", new BertTokenizationUpdate(Tokenization.Truncate.FIRST));
         Map<String, Object> config = new HashMap<>() {
             {
                 put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
         assertThat(NerConfigUpdate.fromMap(config), equalTo(expected));
@@ -55,6 +63,23 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
             ),
             equalTo(new NerConfigUpdate.Builder().setResultsField("ml-results").build().apply(originalConfig))
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new NerConfig(
+                originalConfig.getVocabularyConfig(),
+                tokenization,
+                originalConfig.getClassificationLabels(),
+                originalConfig.getResultsField()
+            ),
+            equalTo(new NerConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate)).build().apply(originalConfig))
+        );
     }
 
     @Override
@@ -73,11 +98,27 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
         if (randomBoolean()) {
             builder.setResultsField(randomAlphaOfLength(8));
         }
+        if (randomBoolean()) {
+            builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values())));
+        }
         return builder.build();
     }
 
     @Override
     protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new NerConfigUpdate(instance.getResultsField(), null);
+        }
         return instance;
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 66 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdateTests.java

@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class NlpConfigUpdateTests extends ESTestCase {
+
+    public void testTokenizationFromMap() {
+
+        Map<String, Object> config = new HashMap<>() {
+            {
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
+            }
+        };
+        assertThat(NlpConfigUpdate.tokenizationFromMap(config), equalTo(new BertTokenizationUpdate(Tokenization.Truncate.FIRST)));
+
+        config = new HashMap<>();
+        assertThat(NlpConfigUpdate.tokenizationFromMap(config), nullValue());
+
+        config = new HashMap<>() {
+            {
+                Map<String, Object> truncate = new HashMap<>();
+                // only the truncate option is updatable
+                truncate.put("do_lower_case", true);
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
+            }
+        };
+        assertThat(NlpConfigUpdate.tokenizationFromMap(config), nullValue());
+
+        Map<String, Object> finalConfig = new HashMap<>() {
+            {
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("not_bert", truncate);
+                put("tokenization", bert);
+            }
+        };
+        ElasticsearchStatusException e = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> NlpConfigUpdate.tokenizationFromMap(finalConfig)
+        );
+        assertThat(e.getMessage(), containsString("unknown tokenization type expecting one of [bert] got [not_bert]"));
+
+    }
+}

+ 44 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -24,10 +27,18 @@ import static org.hamcrest.Matchers.sameInstance;
 public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCase<PassThroughConfigUpdate> {
 
     public void testFromMap() {
-        PassThroughConfigUpdate expected = new PassThroughConfigUpdate("ml-results");
+        PassThroughConfigUpdate expected = new PassThroughConfigUpdate(
+            "ml-results",
+            new BertTokenizationUpdate(Tokenization.Truncate.FIRST)
+        );
         Map<String, Object> config = new HashMap<>() {
             {
                 put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
         assertThat(PassThroughConfigUpdate.fromMap(config), equalTo(expected));
@@ -50,6 +61,22 @@ public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCa
             new PassThroughConfig(originalConfig.getVocabularyConfig(), originalConfig.getTokenization(), "ml-results"),
             equalTo(new PassThroughConfigUpdate.Builder().setResultsField("ml-results").build().apply(originalConfig))
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new PassThroughConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
+            equalTo(
+                new PassThroughConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
+        );
     }
 
     @Override
@@ -68,11 +95,27 @@ public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCa
         if (randomBoolean()) {
             builder.setResultsField(randomAlphaOfLength(8));
         }
+        if (randomBoolean()) {
+            builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values())));
+        }
         return builder.build();
     }
 
     @Override
     protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new PassThroughConfigUpdate(instance.getResultsField(), null);
+        }
         return instance;
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 58 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java

@@ -10,9 +10,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -26,12 +29,22 @@ import static org.hamcrest.Matchers.equalTo;
 public class TextClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<TextClassificationConfigUpdate> {
 
     public void testFromMap() {
-        TextClassificationConfigUpdate expected = new TextClassificationConfigUpdate(List.of("foo", "bar"), 3, "ml-results");
+        TextClassificationConfigUpdate expected = new TextClassificationConfigUpdate(
+            List.of("foo", "bar"),
+            3,
+            "ml-results",
+            new BertTokenizationUpdate(Tokenization.Truncate.FIRST)
+        );
         Map<String, Object> config = new HashMap<>() {
             {
                 put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
                 put(NlpConfig.CLASSIFICATION_LABELS.getPreferredName(), List.of("foo", "bar"));
                 put(NlpConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
         assertThat(TextClassificationConfigUpdate.fromMap(config), equalTo(expected));
@@ -67,8 +80,14 @@ public class TextClassificationConfigUpdateTests extends AbstractBWCSerializatio
         assertFalse(
             new TextClassificationConfigUpdate.Builder().setClassificationLabels(List.of("a", "b"))
                 .build()
-                .isNoop(new TextClassificationConfig.Builder().setClassificationLabels(List.of("c", "d")).setNumTopClasses(3).build())
+                .isNoop(new TextClassificationConfig.Builder().setClassificationLabels(List.of("c", "d")).build())
+        );
+        assertFalse(
+            new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(Tokenization.Truncate.SECOND))
+                .build()
+                .isNoop(new TextClassificationConfig.Builder().setClassificationLabels(List.of("c", "d")).build())
         );
+
     }
 
     public void testApply() {
@@ -100,6 +119,22 @@ public class TextClassificationConfigUpdateTests extends AbstractBWCSerializatio
                     .apply(originalConfig)
             )
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new TextClassificationConfig.Builder(originalConfig).setTokenization(tokenization).build(),
+            equalTo(
+                new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
+        );
     }
 
     public void testApplyWithInvalidLabels() {
@@ -145,11 +180,32 @@ public class TextClassificationConfigUpdateTests extends AbstractBWCSerializatio
         if (randomBoolean()) {
             builder.setResultsField(randomAlphaOfLength(8));
         }
+        if (randomBoolean()) {
+            builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values())));
+        }
         return builder.build();
     }
 
     @Override
     protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new TextClassificationConfigUpdate(
+                instance.getClassificationLabels(),
+                instance.getNumTopClasses(),
+                instance.getResultsField(),
+                null
+            );
+        }
         return instance;
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 44 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -24,10 +27,18 @@ import static org.hamcrest.Matchers.sameInstance;
 public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTestCase<TextEmbeddingConfigUpdate> {
 
     public void testFromMap() {
-        TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate("ml-results");
+        TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate(
+            "ml-results",
+            new BertTokenizationUpdate(Tokenization.Truncate.FIRST)
+        );
         Map<String, Object> config = new HashMap<>() {
             {
                 put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
         assertThat(TextEmbeddingConfigUpdate.fromMap(config), equalTo(expected));
@@ -50,6 +61,22 @@ public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTest
             new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), originalConfig.getTokenization(), "ml-results"),
             equalTo(new TextEmbeddingConfigUpdate.Builder().setResultsField("ml-results").build().apply(originalConfig))
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
+            equalTo(
+                new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
+        );
     }
 
     @Override
@@ -68,11 +95,27 @@ public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTest
         if (randomBoolean()) {
             builder.setResultsField(randomAlphaOfLength(8));
         }
+        if (randomBoolean()) {
+            builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values())));
+        }
         return builder.build();
     }
 
     @Override
     protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new TextEmbeddingConfigUpdate(instance.getResultsField(), null);
+        }
         return instance;
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 58 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -46,16 +49,30 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
 
     @Override
     protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
+        if (version.before(Version.V_8_1_0)) {
+            return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null);
+        }
         return instance;
     }
 
     public void testFromMap() {
-        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(List.of("foo", "bar"), false, "ml-results");
+        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(
+            List.of("foo", "bar"),
+            false,
+            "ml-results",
+            new BertTokenizationUpdate(Tokenization.Truncate.FIRST)
+        );
+
         Map<String, Object> config = new HashMap<>() {
             {
                 put(ZeroShotClassificationConfig.LABELS.getPreferredName(), List.of("foo", "bar"));
                 put(ZeroShotClassificationConfig.MULTI_LABEL.getPreferredName(), false);
                 put(ZeroShotClassificationConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+                Map<String, Object> truncate = new HashMap<>();
+                truncate.put("truncate", "first");
+                Map<String, Object> bert = new HashMap<>();
+                bert.put("bert", truncate);
+                put("tokenization", bert);
             }
         };
         assertThat(ZeroShotClassificationConfigUpdate.fromMap(config), equalTo(expected));
@@ -118,6 +135,30 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
             ),
             equalTo(new ZeroShotClassificationConfigUpdate.Builder().setResultsField("updated-field").build().apply(originalConfig))
         );
+
+        Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
+        Tokenization tokenization = new BertTokenization(
+            originalConfig.getTokenization().doLowerCase(),
+            originalConfig.getTokenization().withSpecialTokens(),
+            originalConfig.getTokenization().maxSequenceLength(),
+            truncate
+        );
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                tokenization,
+                originalConfig.getHypothesisTemplate(),
+                originalConfig.isMultiLabel(),
+                originalConfig.getLabels(),
+                originalConfig.getResultsField()
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
+        );
     }
 
     public void testApplyWithEmptyLabelsInConfigAndUpdate() {
@@ -138,11 +179,26 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
         );
     }
 
+    public void testIsNoop() {
+        assertTrue(new ZeroShotClassificationConfigUpdate.Builder().build().isNoop(ZeroShotClassificationConfigTests.createRandom()));
+    }
+
     public static ZeroShotClassificationConfigUpdate createRandom() {
         return new ZeroShotClassificationConfigUpdate(
             randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomBoolean(),
-            randomBoolean() ? null : randomAlphaOfLength(5)
+            randomBoolean() ? null : randomAlphaOfLength(5),
+            randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()))
         );
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
 }

+ 50 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -484,6 +484,56 @@ public class PyTorchModelIT extends ESRestTestCase {
         );
     }
 
+    public void testTruncation() throws IOException {
+        String modelId = "no-truncation";
+
+        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+        request.setJsonEntity(
+            "{  "
+                + "    \"description\": \"simple model for testing\",\n"
+                + "    \"model_type\": \"pytorch\",\n"
+                + "    \"inference_config\": {\n"
+                + "        \"pass_through\": {\n"
+                + "            \"tokenization\": {"
+                + "              \"bert\": {"
+                + "                \"with_special_tokens\": false,"
+                + "                \"truncate\": \"none\","
+                + "                \"max_sequence_length\": 2"
+                + "              }\n"
+                + "            }\n"
+                + "        }\n"
+                + "    }\n"
+                + "}"
+        );
+        client().performRequest(request);
+
+        putVocabulary(List.of("once", "twice", "thrice"), modelId);
+        putModelDefinition(modelId);
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
+
+        String input = "once twice thrice";
+        ResponseException ex = expectThrows(ResponseException.class, () -> infer("once twice thrice", modelId));
+        assertThat(
+            ex.getMessage(),
+            containsString("Input too large. The tokenized input length [3] exceeds the maximum sequence length [2]")
+        );
+
+        request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
+        request.setJsonEntity(
+            "{"
+                + "\"docs\": [{\"input\":\""
+                + input
+                + "\"}],"
+                + "\"inference_config\": { "
+                + "  \"pass_through\": {"
+                + "    \"tokenization\": {\"bert\": {\"truncate\": \"first\"}}"
+                + "    }"
+                + "  }"
+                + "}"
+        );
+        client().performRequest(request);
+    }
+
     private int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
         int inferenceCount = 0;
         for (var node : nodes) {

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -346,7 +346,9 @@ public class DeploymentManager {
                 NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                 processor.validateInputs(text);
                 assert config instanceof NlpConfig;
-                NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr);
+                NlpConfig nlpConfig = (NlpConfig) config;
+                NlpTask.Request request = processor.getRequestBuilder(nlpConfig)
+                    .buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate());
                 logger.debug(() -> "Inference Request " + request.processInput.utf8ToString());
                 if (request.tokenization.anyTruncated()) {
                     logger.debug("[{}] [{}] input truncated", modelId, requestId);

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
@@ -32,13 +33,13 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     }
 
     @Override
-    public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+    public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
         if (tokenizer.getPadToken().isEmpty()) {
             throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
         }
 
         TokenizationResult tokenization = tokenizer.buildTokenizationResult(
-            inputs.stream().map(tokenizer::tokenize).collect(Collectors.toList())
+            inputs.stream().map(s -> tokenizer.tokenize(s, truncate)).collect(Collectors.toList())
         );
         return buildRequest(tokenization, requestId);
     }

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -54,7 +55,7 @@ public class NlpTask {
             int apply(TokenizationResult.Tokenization tokenization, int index);
         }
 
-        Request buildRequest(List<String> inputs, String requestId) throws IOException;
+        Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException;
 
         Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException;
 

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInfere
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -111,13 +112,13 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
         }
 
         @Override
-        public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+        public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
             if (inputs.size() > 1) {
                 throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
             }
             List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
             for (String label : labels) {
-                tokenizations.add(tokenizer.tokenize(inputs.get(0), LoggerMessageFormat.format(null, hypothesisTemplate, label)));
+                tokenizations.add(tokenizer.tokenize(inputs.get(0), LoggerMessageFormat.format(null, hypothesisTemplate, label), truncate));
             }
             TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations);
             return buildRequest(result, requestId);

+ 2 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -52,7 +52,6 @@ public class BertTokenizer implements NlpTokenizer {
     private final boolean doTokenizeCjKChars;
     private final boolean doStripAccents;
     private final boolean withSpecialTokens;
-    private final Tokenization.Truncate truncate;
     private final Set<String> neverSplit;
     private final int maxSequenceLength;
     private final NlpTask.RequestBuilder requestBuilder;
@@ -64,7 +63,6 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doTokenizeCjKChars,
         boolean doStripAccents,
         boolean withSpecialTokens,
-        Tokenization.Truncate truncate,
         int maxSequenceLength,
         Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
@@ -76,7 +74,6 @@ public class BertTokenizer implements NlpTokenizer {
         this.doTokenizeCjKChars = doTokenizeCjKChars;
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
-        this.truncate = truncate;
         this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
         this.maxSequenceLength = maxSequenceLength;
         this.requestBuilder = requestBuilderFactory.apply(this);
@@ -113,7 +110,7 @@ public class BertTokenizer implements NlpTokenizer {
      * @return A {@link Tokenization}
      */
     @Override
-    public TokenizationResult.Tokenization tokenize(String seq) {
+    public TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncate truncate) {
         var innerResult = innerTokenize(seq);
         List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
         List<Integer> tokenPositionMap = innerResult.v2();
@@ -164,7 +161,7 @@ public class BertTokenizer implements NlpTokenizer {
     }
 
     @Override
-    public TokenizationResult.Tokenization tokenize(String seq1, String seq2) {
+    public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate) {
         var innerResult = innerTokenize(seq1);
         List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq1s = innerResult.v1();
         List<Integer> tokenPositionMapSeq1 = innerResult.v2();
@@ -295,7 +292,6 @@ public class BertTokenizer implements NlpTokenizer {
         protected boolean doLowerCase = false;
         protected boolean doTokenizeCjKChars = true;
         protected boolean withSpecialTokens = true;
-        protected Tokenization.Truncate truncate = Tokenization.Truncate.FIRST;
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
@@ -307,7 +303,6 @@ public class BertTokenizer implements NlpTokenizer {
             this.doLowerCase = tokenization.doLowerCase();
             this.withSpecialTokens = tokenization.withSpecialTokens();
             this.maxSequenceLength = tokenization.maxSequenceLength();
-            this.truncate = tokenization.getTruncate();
         }
 
         private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
@@ -358,11 +353,6 @@ public class BertTokenizer implements NlpTokenizer {
             return this;
         }
 
-        public Builder setTruncate(Tokenization.Truncate truncate) {
-            this.truncate = truncate;
-            return this;
-        }
-
         public BertTokenizer build() {
             // if not set strip accents defaults to the value of doLowerCase
             if (doStripAccents == null) {
@@ -380,7 +370,6 @@ public class BertTokenizer implements NlpTokenizer {
                 doTokenizeCjKChars,
                 doStripAccents,
                 withSpecialTokens,
-                truncate,
                 maxSequenceLength,
                 requestBuilderFactory,
                 neverSplit

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -24,9 +24,9 @@ public interface NlpTokenizer {
 
     TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations);
 
-    TokenizationResult.Tokenization tokenize(String seq);
+    TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncate truncate);
 
-    TokenizationResult.Tokenization tokenize(String seq1, String seq2);
+    TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate);
 
     NlpTask.RequestBuilder requestBuilder();
 

+ 12 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java

@@ -29,11 +29,11 @@ public class BertRequestBuilderTests extends ESTestCase {
     public void testBuildRequest() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
             Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
-            new BertTokenization(null, null, 512, Tokenization.Truncate.NONE)
+            new BertTokenization(null, null, 512, null)
         ).build();
 
         BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
-        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1");
+        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
@@ -57,7 +57,7 @@ public class BertRequestBuilderTests extends ESTestCase {
     public void testInputTooLarge() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
             Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
-            new BertTokenization(null, null, 5, Tokenization.Truncate.NONE)
+            new BertTokenization(null, null, 5, null)
         ).build();
         {
             BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
@@ -65,7 +65,8 @@ public class BertRequestBuilderTests extends ESTestCase {
                 ElasticsearchStatusException.class,
                 () -> requestBuilder.buildRequest(
                     Collections.singletonList("Elasticsearch fun Elasticsearch fun Elasticsearch fun"),
-                    "request1"
+                    "request1",
+                    Tokenization.Truncate.NONE
                 )
             );
 
@@ -78,7 +79,7 @@ public class BertRequestBuilderTests extends ESTestCase {
             BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
             // input will become 3 tokens + the Class and Separator token = 5 which is
             // our max sequence length
-            requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1");
+            requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
         }
     }
 
@@ -101,11 +102,15 @@ public class BertRequestBuilderTests extends ESTestCase {
                 "God",
                 "##zilla"
             ),
-            new BertTokenization(null, null, 512, Tokenization.Truncate.NONE)
+            new BertTokenization(null, null, 512, null)
         ).build();
 
         BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
-        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch", "my little red car", "Godzilla day"), "request1");
+        NlpTask.Request request = requestBuilder.buildRequest(
+            List.of("Elasticsearch", "my little red car", "Godzilla day"),
+            "request1",
+            Tokenization.Truncate.NONE
+        );
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -259,6 +259,6 @@ public class NerProcessorTests extends ESTestCase {
             .setDoLowerCase(true)
             .setWithSpecialTokens(false)
             .build();
-        return tokenizer.buildTokenizationResult(List.of(tokenizer.tokenize(input)));
+        return tokenizer.buildTokenizationResult(List.of(tokenizer.tokenize(input, Tokenization.Truncate.NONE)));
     }
 }

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -87,7 +87,8 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
-        NlpTask.Request request = processor.getRequestBuilder(config).buildRequest(List.of("Elasticsearch fun"), "request1");
+        NlpTask.Request request = processor.getRequestBuilder(config)
+            .buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java

@@ -66,7 +66,7 @@ public class ZeroShotClassificationProcessorTests extends ESTestCase {
 
         NlpTask.Request request = processor.getRequestBuilder(
             (NlpConfig) new ZeroShotClassificationConfigUpdate.Builder().setLabels(List.of("new", "stuff")).build().apply(config)
-        ).buildRequest(List.of("Elasticsearch fun"), "request1");
+        ).buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 

+ 39 - 22
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -50,7 +50,7 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null, Tokenization.Truncate.NONE)
         ).build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
         assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun"));
         assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
@@ -62,7 +62,7 @@ public class BertTokenizerTests extends ESTestCase {
 
         ElasticsearchStatusException ex = expectThrows(
             ElasticsearchStatusException.class,
-            () -> tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla")
+            () -> tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.NONE)
         );
         assertThat(ex.getMessage(), equalTo("Input too large. The tokenized input length [8] exceeds the maximum sequence length [5]"));
 
@@ -72,28 +72,34 @@ public class BertTokenizerTests extends ESTestCase {
         ).build();
 
         // Shouldn't throw
-        tokenizer.tokenize("Elasticsearch fun with Pancake");
+        tokenizer.tokenize("Elasticsearch fun with Pancake", Tokenization.Truncate.NONE);
 
         // Should throw as special chars add two tokens
-        expectThrows(ElasticsearchStatusException.class, () -> specialCharTokenizer.tokenize("Elasticsearch fun with Pancake"));
+        expectThrows(
+            ElasticsearchStatusException.class,
+            () -> specialCharTokenizer.tokenize("Elasticsearch fun with Pancake", Tokenization.Truncate.NONE)
+        );
     }
 
     public void testTokenizeLargeInputTruncation() {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, false, 5, Tokenization.Truncate.FIRST))
             .build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch fun with Pancake and Godzilla",
+            Tokenization.Truncate.FIRST
+        );
         assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun", "with", "Pancake"));
 
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 5, Tokenization.Truncate.FIRST)).build();
-        tokenization = tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla");
+        tokenization = tokenizer.tokenize("Elasticsearch fun with Pancake and Godzilla", Tokenization.Truncate.FIRST);
         assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]"));
     }
 
     public void testTokenizeAppendSpecialTokens() {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
         assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "Elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
@@ -107,7 +113,10 @@ public class BertTokenizerTests extends ESTestCase {
             .setWithSpecialTokens(false)
             .build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch " + specialToken + " fun",
+            Tokenization.Truncate.NONE
+        );
         assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", specialToken, "fun"));
         assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.getTokenMap());
@@ -120,12 +129,12 @@ public class BertTokenizerTests extends ESTestCase {
                 Tokenization.createDefault()
             ).setDoLowerCase(false).setWithSpecialTokens(false).build();
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
             assertThat(tokenization.getTokens(), arrayContaining(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] { 3, 2 }, tokenization.getTokenIds());
             assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenMap());
 
-            tokenization = tokenizer.tokenize("elasticsearch fun");
+            tokenization = tokenizer.tokenize("elasticsearch fun", Tokenization.Truncate.NONE);
             assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
 
@@ -135,7 +144,7 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .build();
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
             assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
     }
@@ -143,12 +152,12 @@ public class BertTokenizerTests extends ESTestCase {
     public void testPunctuation() {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).setWithSpecialTokens(false).build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
         assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "."));
         assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.getTokenMap());
 
-        tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].");
+        tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", Tokenization.Truncate.NONE);
         assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "[MASK]", "."));
         assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.getTokenIds());
         assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
@@ -162,10 +171,10 @@ public class BertTokenizerTests extends ESTestCase {
 
         TokenizationResult tr = tokenizer.buildTokenizationResult(
             List.of(
-                tokenizer.tokenize("Elasticsearch"),
-                tokenizer.tokenize("my little red car"),
-                tokenizer.tokenize("Godzilla day"),
-                tokenizer.tokenize("Godzilla Pancake red car day")
+                tokenizer.tokenize("Elasticsearch", Tokenization.Truncate.NONE),
+                tokenizer.tokenize("my little red car", Tokenization.Truncate.NONE),
+                tokenizer.tokenize("Godzilla day", Tokenization.Truncate.NONE),
+                tokenizer.tokenize("Godzilla Pancake red car day", Tokenization.Truncate.NONE)
             )
         );
         assertThat(tr.getTokenizations(), hasSize(4));
@@ -196,7 +205,11 @@ public class BertTokenizerTests extends ESTestCase {
             .setDoLowerCase(false)
             .setWithSpecialTokens(true)
             .build();
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch is fun",
+            "Godzilla my little red car",
+            Tokenization.Truncate.NONE
+        );
         assertThat(
             tokenization.getTokens(),
             arrayContaining(
@@ -222,7 +235,11 @@ public class BertTokenizerTests extends ESTestCase {
         BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 10, Tokenization.Truncate.FIRST))
             .build();
 
-        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car");
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch is fun",
+            "Godzilla my little red car",
+            Tokenization.Truncate.FIRST
+        );
         assertThat(
             tokenization.getTokens(),
             arrayContaining(
@@ -243,12 +260,12 @@ public class BertTokenizerTests extends ESTestCase {
             ElasticsearchStatusException.class,
             () -> BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 8, Tokenization.Truncate.NONE))
                 .build()
-                .tokenize("Elasticsearch is fun", "Godzilla my little red car")
+                .tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.NONE)
         );
 
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, true, 10, Tokenization.Truncate.SECOND)).build();
 
-        tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car");
+        tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.SECOND);
         assertThat(
             tokenization.getTokens(),
             arrayContaining(
@@ -272,6 +289,6 @@ public class BertTokenizerTests extends ESTestCase {
             .setDoLowerCase(false)
             .setWithSpecialTokens(false)
             .build();
-        expectThrows(Exception.class, () -> tokenizer.tokenize("foo", "foo"));
+        expectThrows(Exception.class, () -> tokenizer.tokenize("foo", "foo", Tokenization.Truncate.NONE));
     }
 }