
[ML] adds new mpnet tokenization for nlp models (#82234)

This commit adds support for MPNet based models.

MPNet models differ from BERT-style models in that:

 - Special tokens are different
 - Input to the model doesn't require token positions.

To configure an MPNet tokenizer for your PyTorch MPNet-based model:

```
"tokenization": {
  "mpnet": {...}
}
```
The options provided to `mpnet` are the same as those of the previously supported `bert` configuration.
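
For illustration only (the values below are hypothetical, not defaults taken from this commit), a fully spelled-out MPNet tokenization block using the options documented in this change might look like:

```
"tokenization": {
  "mpnet": {
    "do_lower_case": true,
    "with_special_tokens": true,
    "max_sequence_length": 512,
    "truncate": "first"
  }
}
```

As with `bert`, each of these fields is optional.
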
Benjamin Trent committed 3 years ago · commit 9dc8aea1cb
27 changed files with 1285 additions and 135 deletions
  1. docs/reference/ml/ml-shared.asciidoc (+20, -0)
  2. docs/reference/ml/trained-models/apis/get-trained-models.asciidoc (+138, -0)
  3. docs/reference/ml/trained-models/apis/put-trained-models.asciidoc (+138, -0)
  4. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java (+26, -0)
  5. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java (+4, -4)
  6. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java (+88, -0)
  7. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java (+111, -0)
  8. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java (+1, -1)
  9. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java (+32, -0)
  10. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java (+55, -0)
  11. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java (+1, -1)
  12. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java (+8, -7)
  13. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java (+1, -1)
  14. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java (+6, -9)
  15. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java (+1, -1)
  16. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java (+6, -9)
  17. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java (+1, -1)
  18. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java (+6, -9)
  19. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java (+1, -1)
  20. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java (+6, -9)
  21. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java (+5, -5)
  22. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java (+66, -0)
  23. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java (+171, -77)
  24. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java (+186, -0)
  25. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java (+7, -0)
  26. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilderTests.java (+105, -0)
  27. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java (+95, -0)

+ 20 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -929,6 +929,13 @@ end::inference-config-classification-prediction-field-type[]
 
 tag::inference-config-nlp-tokenization[]
 Indicates the tokenization to perform and the desired settings.
+The default tokenization configuration is `bert`. Valid tokenization
+values are
++
+--
+* `bert`: Use for BERT-style models
+* `mpnet`: Use for MPNet-style models
+--
 end::inference-config-nlp-tokenization[]
 
 tag::inference-config-nlp-tokenization-bert[]
@@ -970,6 +977,19 @@ Specifies the maximum number of tokens allowed to be output by the tokenizer.
 The default for BERT-style tokenization is `512`.
 end::inference-config-nlp-tokenization-bert-max-sequence-length[]
 
+tag::inference-config-nlp-tokenization-mpnet[]
+MPNet-style tokenization is to be performed with the enclosed settings.
+end::inference-config-nlp-tokenization-mpnet[]
+
+tag::inference-config-nlp-tokenization-mpnet-with-special-tokens[]
+Tokenize with special tokens. The tokens typically included in MPNet-style tokenization are:
++
+--
+* `<s>`: The first token of the sequence being classified.
+* `</s>`: Indicates sequence separation.
+--
+end::inference-config-nlp-tokenization-mpnet-with-special-tokens[]
+
 tag::inference-config-nlp-vocabulary[]
 The configuration for retrieving the vocabulary of the model. The vocabulary is
 then used at inference time. This information is usually provided automatically

+ 138 - 0
docs/reference/ml/trained-models/apis/get-trained-models.asciidoc

@@ -202,6 +202,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -260,6 +283,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -311,6 +357,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -385,6 +454,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 
 `vocabulary`::::
@@ -436,6 +528,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -502,6 +617,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)

+ 138 - 0
docs/reference/ml/trained-models/apis/put-trained-models.asciidoc

@@ -458,6 +458,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -504,6 +527,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -544,6 +590,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -607,6 +676,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 `text_embedding`:::
@@ -646,6 +738,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 `zero_shot_classification`:::
@@ -701,6 +816,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 ====

+ 26 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -41,6 +41,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpd
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenizationUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
@@ -435,6 +437,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 (p, c) -> BertTokenization.fromXContent(p, (boolean) c)
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                Tokenization.class,
+                MPNetTokenization.NAME,
+                (p, c) -> MPNetTokenization.fromXContent(p, (boolean) c)
+            )
+        );
 
         namedXContent.add(
             new NamedXContentRegistry.Entry(
@@ -443,6 +452,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 (p, c) -> BertTokenizationUpdate.fromXContent(p)
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                TokenizationUpdate.class,
+                MPNetTokenizationUpdate.NAME,
+                (p, c) -> MPNetTokenizationUpdate.fromXContent(p)
+            )
+        );
 
         return namedXContent;
     }
@@ -591,6 +607,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(Tokenization.class, BertTokenization.NAME.getPreferredName(), BertTokenization::new)
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(Tokenization.class, MPNetTokenization.NAME.getPreferredName(), MPNetTokenization::new)
+        );
 
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
@@ -599,6 +618,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 BertTokenizationUpdate::new
             )
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TokenizationUpdate.class,
+                MPNetTokenizationUpdate.NAME.getPreferredName(),
+                MPNetTokenizationUpdate::new
+            )
+        );
 
         return namedWriteables;
     }

+ 4 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java

@@ -48,10 +48,6 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
 
     @Override
     public Tokenization apply(Tokenization originalConfig) {
-        if (isNoop()) {
-            return originalConfig;
-        }
-
         if (originalConfig instanceof BertTokenization == false) {
             throw ExceptionsHelper.badRequestException(
                 "Tokenization config of type [{}] can not be updated with a request of type [{}]",
@@ -60,6 +56,10 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
             );
         }
 
+        if (isNoop()) {
+            return originalConfig;
+        }
+
         return new BertTokenization(
             originalConfig.doLowerCase(),
             originalConfig.withSpecialTokens(),
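
The only change in this file is ordering: the tokenization-type check now runs before the no-op short-circuit, so a no-op update can no longer silently pass through a mismatched tokenization type. A minimal sketch of the resulting behavior, assuming the classes from this commit are on the classpath (the example class itself is hypothetical, not part of the commit):

```java
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

public class BertUpdateOrderingSketch {
    public static void main(String[] args) {
        // An MPNet tokenization with all-default (null) settings.
        Tokenization mpnet = new MPNetTokenization(null, null, null, null);

        // A null truncate makes this update a no-op. Before this commit apply() returned
        // the original config before checking its type; with the reordering it rejects
        // the BERT-vs-MPNet mismatch first.
        try {
            new BertTokenizationUpdate((Tokenization.Truncate) null).apply(mpnet);
        } catch (Exception e) {
            System.out.println("rejected before the no-op check: " + e.getMessage());
        }
    }
}
```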

+ 88 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class MPNetTokenization extends Tokenization {
+
+    public static final ParseField NAME = new ParseField("mpnet");
+
+    public static ConstructingObjectParser<MPNetTokenization, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<MPNetTokenization, Void> parser = new ConstructingObjectParser<>(
+            "mpnet_tokenization",
+            ignoreUnknownFields,
+            a -> new MPNetTokenization(
+                (Boolean) a[0],
+                (Boolean) a[1],
+                (Integer) a[2],
+                a[3] == null ? null : Truncate.fromString((String) a[3])
+            )
+        );
+        Tokenization.declareCommonFields(parser);
+        return parser;
+    }
+
+    private static final ConstructingObjectParser<MPNetTokenization, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<MPNetTokenization, Void> STRICT_PARSER = createParser(false);
+
+    public static MPNetTokenization fromXContent(XContentParser parser, boolean lenient) {
+        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    public MPNetTokenization(
+        @Nullable Boolean doLowerCase,
+        @Nullable Boolean withSpecialTokens,
+        @Nullable Integer maxSequenceLength,
+        @Nullable Truncate truncate
+    ) {
+        super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate);
+    }
+
+    public MPNetTokenization(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        return super.equals(o);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+}

+ 111 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class MPNetTokenizationUpdate implements TokenizationUpdate {
+
+    public static final ParseField NAME = MPNetTokenization.NAME;
+
+    public static ConstructingObjectParser<MPNetTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
+        "mpnet_tokenization_update",
+        a -> new MPNetTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]))
+    );
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
+    }
+
+    public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Tokenization.Truncate truncate;
+
+    public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate) {
+        this.truncate = truncate;
+    }
+
+    public MPNetTokenizationUpdate(StreamInput in) throws IOException {
+        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
+    }
+
+    @Override
+    public Tokenization apply(Tokenization originalConfig) {
+        if (originalConfig instanceof MPNetTokenization == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Tokenization config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName()
+            );
+        }
+
+        if (isNoop()) {
+            return originalConfig;
+        }
+
+        return new MPNetTokenization(
+            originalConfig.doLowerCase(),
+            originalConfig.withSpecialTokens(),
+            originalConfig.maxSequenceLength(),
+            this.truncate
+        );
+    }
+
+    @Override
+    public boolean isNoop() {
+        return truncate == null;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return MPNetTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(truncate);
+    }
+
+    @Override
+    public String getName() {
+        return MPNetTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
+        return truncate == that.truncate;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(truncate);
+    }
+}
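
Taken together, `apply` mirrors `BertTokenizationUpdate`: a set `truncate` produces a copy of the original settings with the new truncation, a null `truncate` is a no-op, and a type mismatch is rejected. A minimal sketch of that behavior, assuming the classes from this commit are available (the example class itself is hypothetical, not part of the commit):

```java
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

public class MPNetTokenizationUpdateSketch {
    public static void main(String[] args) {
        MPNetTokenization original = new MPNetTokenization(true, true, 512, Tokenization.Truncate.FIRST);

        // A non-null truncate yields a new MPNetTokenization that keeps the original
        // lower-casing, special-token and length settings but swaps the truncation.
        Tokenization updated = new MPNetTokenizationUpdate(Tokenization.Truncate.SECOND).apply(original);

        // A null truncate is a no-op and returns the original instance unchanged.
        Tokenization unchanged = new MPNetTokenizationUpdate((Tokenization.Truncate) null).apply(original);
        System.out.println(unchanged == original); // true: same instance

        // Applying an MPNet update to a BERT tokenization is rejected with a bad-request error.
        try {
            new MPNetTokenizationUpdate(Tokenization.Truncate.FIRST)
                .apply(new BertTokenization(true, true, 512, Tokenization.Truncate.FIRST));
        } catch (Exception e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```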

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -50,7 +50,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
     public static FillMaskConfig createRandom() {
         return new FillMaskConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomInt(),
             randomBoolean() ? null : randomAlphaOfLength(5)
         );

+ 32 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+public final class InferenceConfigTestScaffolding {
+
+    static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) {
+        return tokenization instanceof MPNetTokenization
+            ? new MPNetTokenization(
+                tokenization.doLowerCase(),
+                tokenization.withSpecialTokens(),
+                tokenization.maxSequenceLength(),
+                truncate
+            )
+            : new BertTokenization(
+                tokenization.doLowerCase(),
+                tokenization.withSpecialTokens(),
+                tokenization.maxSequenceLength(),
+                truncate
+            );
+    }
+
+    static TokenizationUpdate createTokenizationUpdate(Tokenization tokenization, Tokenization.Truncate truncate) {
+        return tokenization instanceof MPNetTokenization ? new MPNetTokenizationUpdate(truncate) : new BertTokenizationUpdate(truncate);
+    }
+
+}

+ 55 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java

@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+
+public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase<MPNetTokenization> {
+
+    private boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected MPNetTokenization doParseInstance(XContentParser parser) throws IOException {
+        return MPNetTokenization.createParser(lenient).apply(parser, null);
+    }
+
+    @Override
+    protected Writeable.Reader<MPNetTokenization> instanceReader() {
+        return MPNetTokenization::new;
+    }
+
+    @Override
+    protected MPNetTokenization createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected MPNetTokenization mutateInstanceForVersion(MPNetTokenization instance, Version version) {
+        return instance;
+    }
+
+    public static MPNetTokenization createRandom() {
+        return new MPNetTokenization(
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomIntBetween(1, 1024),
+            randomBoolean() ? null : randomFrom(Tokenization.Truncate.values())
+        );
+    }
+}

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -50,7 +50,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
     public static NerConfig createRandom() {
         return new NerConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomAlphaOfLength(5)
         );

+ 8 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -65,12 +67,7 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new NerConfig(
                 originalConfig.getVocabularyConfig(),
@@ -78,7 +75,11 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
                 originalConfig.getClassificationLabels(),
                 originalConfig.getResultsField()
             ),
-            equalTo(new NerConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate)).build().apply(originalConfig))
+            equalTo(
+                new NerConfigUpdate.Builder().setTokenizationUpdate(createTokenizationUpdate(originalConfig.getTokenization(), truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
         );
     }
 

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java

@@ -50,7 +50,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
     public static PassThroughConfig createRandom() {
         return new PassThroughConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -63,18 +65,13 @@ public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCa
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new PassThroughConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
             equalTo(
-                new PassThroughConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new PassThroughConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java

@@ -69,7 +69,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
     public static TextClassificationConfig createRandom() {
         return new TextClassificationConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomList(2, 5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomBoolean() ? -1 : randomIntBetween(1, 10),
             randomBoolean() ? null : randomAlphaOfLength(6)

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java

@@ -23,6 +23,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -121,18 +123,13 @@ public class TextClassificationConfigUpdateTests extends AbstractBWCSerializatio
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new TextClassificationConfig.Builder(originalConfig).setTokenization(tokenization).build(),
             equalTo(
-                new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java

@@ -50,7 +50,7 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEm
     public static TextEmbeddingConfig createRandom() {
         return new TextEmbeddingConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -63,18 +65,13 @@ public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTest
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
             equalTo(
-                new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java

@@ -52,7 +52,7 @@ public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCa
         return new ZeroShotClassificationConfig(
             randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomAlphaOfLength(10),
             randomBoolean(),
             randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -22,6 +22,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -137,12 +139,7 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new ZeroShotClassificationConfig(
                 originalConfig.getClassificationLabels(),
@@ -154,9 +151,9 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
                 originalConfig.getResultsField()
             ),
             equalTo(
-                new ZeroShotClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new ZeroShotClassificationConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 5 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -11,7 +11,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
@@ -26,16 +26,16 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     static final String ARG2 = "arg_2";
     static final String ARG3 = "arg_3";
 
-    private final BertTokenizer tokenizer;
+    private final NlpTokenizer tokenizer;
 
-    public BertRequestBuilder(BertTokenizer tokenizer) {
+    public BertRequestBuilder(NlpTokenizer tokenizer) {
         this.tokenizer = tokenizer;
     }
 
     @Override
     public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
         if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
         }
 
         TokenizationResult tokenization = tokenizer.buildTokenizationResult(
@@ -47,7 +47,7 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     @Override
     public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
         if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
         }
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
     }

+ 66 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java

@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class MPNetRequestBuilder implements NlpTask.RequestBuilder {
+
+    static final String REQUEST_ID = "request_id";
+    static final String TOKENS = "tokens";
+    static final String ARG1 = "arg_1";
+
+    private final NlpTokenizer tokenizer;
+
+    public MPNetRequestBuilder(NlpTokenizer tokenizer) {
+        this.tokenizer = tokenizer;
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
+        if (tokenizer.getPadTokenId().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
+        }
+
+        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+            inputs.stream().map(s -> tokenizer.tokenize(s, truncate)).collect(Collectors.toList())
+        );
+        return buildRequest(tokenization, requestId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
+        if (tokenizer.getPadTokenId().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
+        }
+        return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
+    }
+
+    static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(REQUEST_ID, requestId);
+
+        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
+        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        return BytesReference.bytes(builder);
+    }
+
+}
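
For reference, the body written by `jsonRequest` carries only three fields; unlike the BERT request, which also writes `arg_2` and `arg_3`, only `arg_1` accompanies the tokens here, since the commit notes that MPNet input does not require token positions. A rough, hypothetical payload for a single input (the token ids are made up, and the exact padding of `tokens` and `arg_1` is determined by `NlpTask.RequestBuilder.writePaddedTokens`):

```
{
  "request_id": "some-request-id",
  "tokens": [[0, 2023, 2003, 2019, 2742, 2]],
  "arg_1": [[1, 1, 1, 1, 1, 1]]
}
```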

+ 171 - 77
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -20,6 +20,8 @@ import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.function.Function;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
 
 /**
  * Performs basic tokenization and normalization of input text
@@ -41,7 +43,7 @@ public class BertTokenizer implements NlpTokenizer {
 
     public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;
 
-    private final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
+    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 
     private final WordPieceTokenizer wordPieceTokenizer;
     private final List<String> originalVocab;
@@ -50,10 +52,17 @@ public class BertTokenizer implements NlpTokenizer {
     private final boolean doLowerCase;
     private final boolean doTokenizeCjKChars;
     private final boolean doStripAccents;
-    private final boolean withSpecialTokens;
+    protected final boolean withSpecialTokens;
     private final Set<String> neverSplit;
     private final int maxSequenceLength;
     private final NlpTask.RequestBuilder requestBuilder;
+    private final String sepToken;
+    protected final int sepTokenId;
+    private final String clsToken;
+    private final int clsTokenId;
+    private final String padToken;
+    private final String maskToken;
+    private final String unknownToken;
 
     protected BertTokenizer(
         List<String> originalVocab,
@@ -63,37 +72,97 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
     ) {
-        wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
+        this(
+            originalVocab,
+            vocab,
+            doLowerCase,
+            doTokenizeCjKChars,
+            doStripAccents,
+            withSpecialTokens,
+            maxSequenceLength,
+            requestBuilderFactory,
+            Sets.union(neverSplit, NEVER_SPLIT),
+            SEPARATOR_TOKEN,
+            CLASS_TOKEN,
+            PAD_TOKEN,
+            MASK_TOKEN,
+            UNKNOWN_TOKEN
+        );
+    }
+
+    protected BertTokenizer(
+        List<String> originalVocab,
+        SortedMap<String, Integer> vocab,
+        boolean doLowerCase,
+        boolean doTokenizeCjKChars,
+        boolean doStripAccents,
+        boolean withSpecialTokens,
+        int maxSequenceLength,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Set<String> neverSplit,
+        String sepToken,
+        String clsToken,
+        String padToken,
+        String maskToken,
+        String unknownToken
+    ) {
+        wordPieceTokenizer = new WordPieceTokenizer(vocab, unknownToken, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
         this.originalVocab = originalVocab;
         this.vocab = vocab;
         this.doLowerCase = doLowerCase;
         this.doTokenizeCjKChars = doTokenizeCjKChars;
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
-        this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
+        this.neverSplit = neverSplit;
         this.maxSequenceLength = maxSequenceLength;
         this.requestBuilder = requestBuilderFactory.apply(this);
-        if (vocab.containsKey(UNKNOWN_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN);
+        if (vocab.containsKey(unknownToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", unknownToken);
         }
-        if (vocab.containsKey(PAD_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN);
+        if (vocab.containsKey(padToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", padToken);
         }
 
         if (withSpecialTokens) {
-            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
+            Set<String> missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
             if (missingSpecialTokens.isEmpty() == false) {
                 throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens);
             }
+            this.sepTokenId = vocab.get(sepToken);
+            this.clsTokenId = vocab.get(clsToken);
+        } else {
+            this.sepTokenId = -1;
+            this.clsTokenId = -1;
         }
+        this.sepToken = sepToken;
+        this.clsToken = clsToken;
+        this.padToken = padToken;
+        this.maskToken = maskToken;
+        this.unknownToken = unknownToken;
+    }
+
+    public String getSepToken() {
+        return sepToken;
+    }
+
+    public String getClsToken() {
+        return clsToken;
+    }
+
+    public String getPadToken() {
+        return padToken;
+    }
+
+    public String getUnknownToken() {
+        return unknownToken;
     }
 
     @Override
     public OptionalInt getPadTokenId() {
-        Integer pad = vocab.get(PAD_TOKEN);
+        Integer pad = vocab.get(this.padToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
@@ -103,7 +172,7 @@ public class BertTokenizer implements NlpTokenizer {
 
     @Override
     public OptionalInt getMaskTokenId() {
-        Integer pad = vocab.get(MASK_TOKEN);
+        Integer pad = vocab.get(this.maskToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
@@ -113,7 +182,7 @@ public class BertTokenizer implements NlpTokenizer {
 
     @Override
     public String getMaskToken() {
-        return MASK_TOKEN;
+        return maskToken;
     }
 
     @Override
@@ -150,6 +219,7 @@ public class BertTokenizer implements NlpTokenizer {
                 case SECOND:
                     isTruncated = true;
                     wordPieceTokenIds = wordPieceTokenIds.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
+                    tokenPositionMap = tokenPositionMap.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -158,31 +228,16 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        if (withSpecialTokens) {
-            tokenIds[0] = vocab.get(CLASS_TOKEN);
-            tokenMap[0] = SPECIAL_TOKEN_POSITION;
-        }
-
-        int i = withSpecialTokens ? 1 : 0;
-        final int decrementHandler = withSpecialTokens ? 1 : 0;
-        for (var tokenId : wordPieceTokenIds) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMap.get(i - decrementHandler);
-            i++;
-        }
-
-        if (withSpecialTokens) {
-            tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-            tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
-
-        return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, tokenIds, tokenMap);
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIds, tokenPositionMap)
+            .addEndTokensIfNecessary();
+        return new TokenizationResult.Tokenization(
+            seq,
+            innerResult.tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
     }
 
     @Override
@@ -196,39 +251,47 @@ public class BertTokenizer implements NlpTokenizer {
         if (withSpecialTokens == false) {
             throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
         }
-        // [CLS] seq1 [SEP] seq2 [SEP]
-        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + 3;
+        int extraTokens = getNumExtraTokensForSeqPair();
+        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + extraTokens;
 
         boolean isTruncated = false;
         if (numTokens > maxSequenceLength) {
             switch (truncate) {
                 case FIRST:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the second sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq2.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq2.size());
+                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size()
+                    );
+                    tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size());
                     break;
                 case SECOND:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the first sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq1.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq1.size());
+                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size()
+                    );
+                    tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size());
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -237,38 +300,27 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        tokenIds[0] = vocab.get(CLASS_TOKEN);
-        tokenMap[0] = SPECIAL_TOKEN_POSITION;
-
-        int i = 1;
-        for (var tokenId : wordPieceTokenIdsSeq1) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq1.get(i - 1);
-            i++;
         }
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-        ++i;
-
-        int j = 0;
-        for (var tokenId : wordPieceTokenIdsSeq2) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq2.get(j);
-            i++;
-            j++;
-        }
-
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIdsSeq1, tokenPositionMapSeq1)
+            .addTokens(wordPieceTokenIdsSeq2, tokenPositionMapSeq2)
+            .addEndTokensIfNecessary();
         List<DelimitedToken> tokens = new ArrayList<>(innerResultSeq1.tokens);
         tokens.addAll(innerResultSeq2.tokens);
-        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(
+            seq1 + seq2,
+            tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
+    }
+
+    protected BertTokenizationBuilder bertTokenizationBuilder() {
+        return new BertTokenizationBuilder();
+    }
+
+    protected int getNumExtraTokensForSeqPair() {
+        return 3;
     }
 
     private InnerTokenization innerTokenize(String seq) {
@@ -280,7 +332,7 @@ public class BertTokenizer implements NlpTokenizer {
         for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); sourceIndex++) {
             String token = tokenSequences.get(sourceIndex).getToken();
             if (neverSplit.contains(token)) {
-                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)));
+                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(unknownToken)));
                 tokenPositionMap.add(sourceIndex);
             } else {
                 List<Integer> tokens = wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex));
@@ -319,6 +371,48 @@ public class BertTokenizer implements NlpTokenizer {
         return new Builder(vocab, tokenization);
     }
 
+    protected class BertTokenizationBuilder {
+        Stream.Builder<IntStream> tokenIds;
+        Stream.Builder<IntStream> tokenMap;
+        int numSeq;
+
+        BertTokenizationBuilder() {
+            tokenIds = Stream.builder();
+            tokenMap = Stream.builder();
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+        }
+
+        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (numSeq > 0 && withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            numSeq++;
+            return this;
+        }
+
+        BertTokenizationBuilder addEndTokensIfNecessary() {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+
+        int[] buildIds() {
+            return tokenIds.build().flatMapToInt(Function.identity()).toArray();
+        }
+
+        int[] buildMap() {
+            return tokenMap.build().flatMapToInt(Function.identity()).toArray();
+        }
+    }
+
     public static class Builder {
 
         protected final List<String> originalVocab;
@@ -329,7 +423,7 @@ public class BertTokenizer implements NlpTokenizer {
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
-        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
+        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
 
         protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
@@ -382,7 +476,7 @@ public class BertTokenizer implements NlpTokenizer {
             return this;
         }
 
-        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
             this.requestBuilderFactory = requestBuilderFactory;
             return this;
         }
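
The refactored `tokenize` methods above delegate array assembly to the new `BertTokenizationBuilder`, which collects each chunk of token ids (and the matching position map) as an `IntStream`, adds the class/separator ids wherever `withSpecialTokens` requires them, and flattens everything at the end with `flatMapToInt`. A minimal standalone sketch of that flattening pattern, using made-up ids (`12` for the class token, `13` for the separator, `-1` for the special-token position), not code from this commit:

```
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class FlattenSketch {
    public static void main(String[] args) {
        int clsId = 12, sepId = 13, specialPos = -1;      // hypothetical ids for illustration
        List<Integer> seq = List.of(0, 1, 3);             // word-piece ids for one sequence
        List<Integer> positions = List.of(0, 0, 1);       // source-token position for each id

        Stream.Builder<IntStream> ids = Stream.builder();
        Stream.Builder<IntStream> map = Stream.builder();
        ids.add(IntStream.of(clsId));                     // leading class token
        map.add(IntStream.of(specialPos));
        ids.add(seq.stream().mapToInt(Integer::valueOf)); // the sequence itself
        map.add(positions.stream().mapToInt(Integer::valueOf));
        ids.add(IntStream.of(sepId));                     // trailing separator
        map.add(IntStream.of(specialPos));

        // flatMapToInt concatenates the chunks into the flat arrays the builder returns
        int[] tokenIds = ids.build().flatMapToInt(Function.identity()).toArray();
        int[] tokenMap = map.build().flatMapToInt(Function.identity()).toArray();
        System.out.println(java.util.Arrays.toString(tokenIds)); // [12, 0, 1, 3, 13]
        System.out.println(java.util.Arrays.toString(tokenMap)); // [-1, 0, 0, 1, -1]
    }
}
```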

+ 186 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java

@@ -0,0 +1,186 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.IntStream;
+
+/**
+ * Performs basic tokenization and normalization of input text
+ * then tokenizes with the WordPiece algorithm using the given
+ * vocabulary. Unlike {@link BertTokenizer}, it uses the MPNet special
+ * tokens ({@code <s>}, {@code </s>}, {@code <pad>} and {@code <mask>}).
+ */
+public class MPNetTokenizer extends BertTokenizer {
+
+    public static final String UNKNOWN_TOKEN = "[UNK]";
+    public static final String SEPARATOR_TOKEN = "</s>";
+    public static final String PAD_TOKEN = "<pad>";
+    public static final String CLASS_TOKEN = "<s>";
+    public static final String MASK_TOKEN = "<mask>";
+    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
+
+    protected MPNetTokenizer(
+        List<String> originalVocab,
+        SortedMap<String, Integer> vocab,
+        boolean doLowerCase,
+        boolean doTokenizeCjKChars,
+        boolean doStripAccents,
+        boolean withSpecialTokens,
+        int maxSequenceLength,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Set<String> neverSplit
+    ) {
+        super(
+            originalVocab,
+            vocab,
+            doLowerCase,
+            doTokenizeCjKChars,
+            doStripAccents,
+            withSpecialTokens,
+            maxSequenceLength,
+            requestBuilderFactory,
+            Sets.union(neverSplit, NEVER_SPLIT),
+            SEPARATOR_TOKEN,
+            CLASS_TOKEN,
+            PAD_TOKEN,
+            MASK_TOKEN,
+            UNKNOWN_TOKEN
+        );
+    }
+
+    @Override
+    protected int getNumExtraTokensForSeqPair() {
+        return 4;
+    }
+
+    @Override
+    protected BertTokenizationBuilder bertTokenizationBuilder() {
+        return new MPNetTokenizationBuilder();
+    }
+
+    protected class MPNetTokenizationBuilder extends BertTokenizationBuilder {
+
+        @Override
+        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (numSeq > 0 && withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId, sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION, SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            numSeq++;
+            return this;
+        }
+
+    }
+
+    public static Builder mpBuilder(List<String> vocab, Tokenization tokenization) {
+        return new Builder(vocab, tokenization);
+    }
+
+    public static class Builder {
+
+        protected final List<String> originalVocab;
+        protected final SortedMap<String, Integer> vocab;
+        protected boolean doLowerCase = false;
+        protected boolean doTokenizeCjKChars = true;
+        protected boolean withSpecialTokens = true;
+        protected int maxSequenceLength;
+        protected Boolean doStripAccents = null;
+        protected Set<String> neverSplit;
+        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = MPNetRequestBuilder::new;
+
+        protected Builder(List<String> vocab, Tokenization tokenization) {
+            this.originalVocab = vocab;
+            this.vocab = buildSortedVocab(vocab);
+            this.doLowerCase = tokenization.doLowerCase();
+            this.withSpecialTokens = tokenization.withSpecialTokens();
+            this.maxSequenceLength = tokenization.maxSequenceLength();
+        }
+
+        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
+            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
+            for (int i = 0; i < vocab.size(); i++) {
+                sortedVocab.put(vocab.get(i), i);
+            }
+            return sortedVocab;
+        }
+
+        public Builder setDoLowerCase(boolean doLowerCase) {
+            this.doLowerCase = doLowerCase;
+            return this;
+        }
+
+        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
+            this.doTokenizeCjKChars = doTokenizeCjKChars;
+            return this;
+        }
+
+        public Builder setDoStripAccents(Boolean doStripAccents) {
+            this.doStripAccents = doStripAccents;
+            return this;
+        }
+
+        public Builder setNeverSplit(Set<String> neverSplit) {
+            this.neverSplit = neverSplit;
+            return this;
+        }
+
+        public Builder setMaxSequenceLength(int maxSequenceLength) {
+            this.maxSequenceLength = maxSequenceLength;
+            return this;
+        }
+
+        /**
+         * Include the class ({@code <s>}) and separator ({@code </s>}) tokens
+         * @param withSpecialTokens if true include the class and separator tokens
+         * @return this
+         */
+        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
+            this.withSpecialTokens = withSpecialTokens;
+            return this;
+        }
+
+        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+            this.requestBuilderFactory = requestBuilderFactory;
+            return this;
+        }
+
+        public MPNetTokenizer build() {
+            // if not set strip accents defaults to the value of doLowerCase
+            if (doStripAccents == null) {
+                doStripAccents = doLowerCase;
+            }
+
+            if (neverSplit == null) {
+                neverSplit = Collections.emptySet();
+            }
+
+            return new MPNetTokenizer(
+                originalVocab,
+                vocab,
+                doLowerCase,
+                doTokenizeCjKChars,
+                doStripAccents,
+                withSpecialTokens,
+                maxSequenceLength,
+                requestBuilderFactory,
+                neverSplit
+            );
+        }
+    }
+}
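
MPNet keeps the BERT WordPiece pipeline but swaps the special tokens and, for sequence pairs, separates the two sequences with a doubled `</s>`, which is why `getNumExtraTokensForSeqPair()` returns 4 rather than 3. A minimal usage sketch; the tiny vocabulary and the resulting ids are made up for illustration (they mirror the behaviour asserted in the unit tests further down), so treat the exact numbers as assumptions:

```
import java.util.List;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

class MPNetTokenizerSketch {
    public static void main(String[] args) {
        // toy vocabulary containing the MPNet special tokens the constructor validates
        List<String> vocab = List.of(
            "Elastic", "##search", "fun", "is",   // ids 0..3
            MPNetTokenizer.CLASS_TOKEN,           // <s>   -> id 4
            MPNetTokenizer.SEPARATOR_TOKEN,       // </s>  -> id 5
            MPNetTokenizer.PAD_TOKEN,             // <pad> -> id 6
            MPNetTokenizer.UNKNOWN_TOKEN          // [UNK] -> id 7
        );

        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(vocab, new MPNetTokenization(null, null, 512, null)).build();

        // single sequence: <s> tokens </s>
        TokenizationResult.Tokenization single = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
        // single.getTokenIds() -> [4, 0, 1, 2, 5]

        // sequence pair: <s> seq1 </s> </s> seq2 </s> (two separators between the sequences)
        TokenizationResult.Tokenization pair = tokenizer.tokenize(
            "Elasticsearch fun",
            "Elasticsearch is fun",
            Tokenization.Truncate.NONE
        );
        // pair.getTokenIds() -> [4, 0, 1, 2, 5, 5, 0, 1, 3, 2, 5]
    }
}
```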

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -8,9 +8,11 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
 
@@ -32,6 +34,8 @@ public interface NlpTokenizer {
 
     OptionalInt getPadTokenId();
 
+    String getPadToken();
+
     OptionalInt getMaskTokenId();
 
     String getMaskToken();
@@ -42,6 +46,9 @@ public interface NlpTokenizer {
         if (params instanceof BertTokenization) {
             return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
         }
+        if (params instanceof MPNetTokenization) {
+            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).setRequestBuilderFactory(MPNetRequestBuilder::new).build();
+        }
         throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
     }
 }

+ 105 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilderTests.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizerTests.TEST_CASED_VOCAB;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasSize;
+
+public class MPNetRequestBuilderTests extends ESTestCase {
+
+    public void testBuildRequest() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
+
+        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(12, 0, 1, 3, 13), firstListItemFromMap("tokens", jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), firstListItemFromMap("arg_1", jsonDocAsMap));
+    }
+
+    @SuppressWarnings("unchecked")
+    private List<Integer> firstListItemFromMap(String name, Map<String, Object> jsonDocAsMap) {
+        return nthListItemFromMap(name, 0, jsonDocAsMap);
+    }
+
+    @SuppressWarnings("unchecked")
+    public static List<Integer> nthListItemFromMap(String name, int n, Map<String, Object> jsonDocAsMap) {
+        return ((List<List<Integer>>) jsonDocAsMap.get(name)).get(n);
+    }
+
+    public void testInputTooLarge() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 5, null)).build();
+        {
+            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            ElasticsearchStatusException e = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> requestBuilder.buildRequest(
+                    Collections.singletonList("Elasticsearch fun Elasticsearch fun Elasticsearch fun"),
+                    "request1",
+                    Tokenization.Truncate.NONE
+                )
+            );
+
+            assertThat(
+                e.getMessage(),
+                containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]")
+            );
+        }
+        {
+            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            // input will become 3 tokens + the class and separator tokens = 5, which is
+            // our max sequence length
+            requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testBatchWithPadding() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
+
+        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(
+            List.of("Elasticsearch", "my little red car", "Godzilla day"),
+            "request1",
+            Tokenization.Truncate.NONE
+        );
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
+
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(12, 0, 1, 13, 19, 19), nthListItemFromMap("tokens", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 19, 19), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(12, 4, 5, 6, 7, 13), nthListItemFromMap("tokens", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(12, 8, 9, 16, 13, 19), nthListItemFromMap("tokens", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 19), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
+    }
+}
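
The batch test above relies on the request builder right-padding every tokenized input to the longest sequence in the batch with the `<pad>` id (19 in `TEST_CASED_VOCAB`), so the `tokens` field stays rectangular. A standalone sketch of that padding step only, assuming the token ids asserted in `testBatchWithPadding`; it is not the `MPNetRequestBuilder` implementation itself:

```
import java.util.ArrayList;
import java.util.List;

class PaddingSketch {
    // right-pad every row with padId until all rows match the longest one
    static List<List<Integer>> pad(List<List<Integer>> batch, int padId) {
        int max = batch.stream().mapToInt(List::size).max().orElse(0);
        List<List<Integer>> out = new ArrayList<>();
        for (List<Integer> row : batch) {
            List<Integer> padded = new ArrayList<>(row);
            while (padded.size() < max) {
                padded.add(padId);
            }
            out.add(padded);
        }
        return out;
    }

    public static void main(String[] args) {
        // token ids from testBatchWithPadding, before padding (class/separator tokens already added)
        List<List<Integer>> batch = List.of(
            List.of(12, 0, 1, 13),         // "Elasticsearch"
            List.of(12, 4, 5, 6, 7, 13),   // "my little red car"
            List.of(12, 8, 9, 16, 13)      // "Godzilla day"
        );
        // prints [[12, 0, 1, 13, 19, 19], [12, 4, 5, 6, 7, 13], [12, 8, 9, 16, 13, 19]]
        System.out.println(pad(batch, 19));
    }
}
```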

+ 95 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.contains;
+
+public class MPNetTokenizerTests extends ESTestCase {
+
+    public static final List<String> TEST_CASED_VOCAB = List.of(
+        "Elastic",
+        "##search",
+        "is",
+        "fun",
+        "my",
+        "little",
+        "red",
+        "car",
+        "God",
+        "##zilla",
+        ".",
+        ",",
+        MPNetTokenizer.CLASS_TOKEN,
+        MPNetTokenizer.SEPARATOR_TOKEN,
+        MPNetTokenizer.MASK_TOKEN,
+        MPNetTokenizer.UNKNOWN_TOKEN,
+        "day",
+        "Pancake",
+        "with",
+        MPNetTokenizer.PAD_TOKEN
+    );
+
+    private List<String> tokenStrings(List<DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList());
+    }
+
+    public void testTokenize() {
+        BertTokenizer tokenizer = MPNetTokenizer.mpBuilder(
+            TEST_CASED_VOCAB,
+            new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
+        ).build();
+
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+        assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", "fun"));
+        assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+    }
+
+    public void testMultiSeqTokenization() {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(
+            TEST_CASED_VOCAB,
+            new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
+        ).setDoLowerCase(false).setWithSpecialTokens(true).build();
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch is fun",
+            "Godzilla my little red car",
+            Tokenization.Truncate.NONE
+        );
+
+        var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+        assertThat(
+            tokenStream,
+            contains(
+                MPNetTokenizer.CLASS_TOKEN,
+                "Elastic",
+                "##search",
+                "is",
+                "fun",
+                MPNetTokenizer.SEPARATOR_TOKEN,
+                MPNetTokenizer.SEPARATOR_TOKEN,
+                "God",
+                "##zilla",
+                "my",
+                "little",
+                "red",
+                "car",
+                MPNetTokenizer.SEPARATOR_TOKEN
+            )
+        );
+        assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.getTokenIds());
+    }
+
+}