
[ML] adds new mpnet tokenization for nlp models (#82234)

This commit adds support for MPNet-based models.

MPNet models differ from BERT-style models in that:

 - The special tokens are different.
 - Input to the model doesn't require token positions.

To configure an MPNet tokenizer for your PyTorch MPNet-based model:

```
"tokenization": {
  "mpnet": {...}
}
```
The options provided to `mpnet` are the same as those for the previously supported `bert` configuration.
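For instance, a fully specified `mpnet` block could look like the following (illustrative values; the option names are the ones documented in this change):

```
"tokenization": {
  "mpnet": {
    "do_lower_case": false,
    "with_special_tokens": true,
    "max_sequence_length": 512,
    "truncate": "first"
  }
}
```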
Benjamin Trent, 3 years ago
commit 9dc8aea1cb
27 changed files with 1285 additions and 135 deletions
  1. +20 -0  docs/reference/ml/ml-shared.asciidoc
  2. +138 -0  docs/reference/ml/trained-models/apis/get-trained-models.asciidoc
  3. +138 -0  docs/reference/ml/trained-models/apis/put-trained-models.asciidoc
  4. +26 -0  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  5. +4 -4  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java
  6. +88 -0  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java
  7. +111 -0  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java
  8. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java
  9. +32 -0  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java
  10. +55 -0  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java
  11. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java
  12. +8 -7  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java
  13. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java
  14. +6 -9  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java
  15. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java
  16. +6 -9  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java
  17. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java
  18. +6 -9  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java
  19. +1 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java
  20. +6 -9  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java
  21. +5 -5  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  22. +66 -0  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java
  23. +171 -77  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  24. +186 -0  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java
  25. +7 -0  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  26. +105 -0  x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilderTests.java
  27. +95 -0  x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java

+ 20 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -929,6 +929,13 @@ end::inference-config-classification-prediction-field-type[]
 
 tag::inference-config-nlp-tokenization[]
 Indicates the tokenization to perform and the desired settings.
+The default tokenization configuration is `bert`. Valid tokenization
+values are
++
+--
+* `bert`: Use for BERT-style models
+* `mpnet`: Use for MPNet-style models
+--
 end::inference-config-nlp-tokenization[]
 
 tag::inference-config-nlp-tokenization-bert[]
@@ -970,6 +977,19 @@ Specifies the maximum number of tokens allowed to be output by the tokenizer.
 The default for BERT-style tokenization is `512`.
 end::inference-config-nlp-tokenization-bert-max-sequence-length[]
 
+tag::inference-config-nlp-tokenization-mpnet[]
+MPNet-style tokenization is to be performed with the enclosed settings.
+end::inference-config-nlp-tokenization-mpnet[]
+
+tag::inference-config-nlp-tokenization-mpnet-with-special-tokens[]
+Tokenize with special tokens. The tokens typically included in MPNet-style tokenization are:
++
+--
+* `<s>`: The first token of the sequence being classified.
+* `</s>`: Indicates sequence separation.
+--
+end::inference-config-nlp-tokenization-mpnet-with-special-tokens[]
+
 tag::inference-config-nlp-vocabulary[]
 The configuration for retrieving the vocabulary of the model. The vocabulary is
 then used at inference time. This information is usually provided automatically

+ 138 - 0
docs/reference/ml/trained-models/apis/get-trained-models.asciidoc

@@ -202,6 +202,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -260,6 +283,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -311,6 +357,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -385,6 +454,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 
 `vocabulary`::::
@@ -436,6 +528,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)
@@ -502,6 +617,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 ========
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+========
 =======
 `vocabulary`::::
 (Optional, object)

+ 138 - 0
docs/reference/ml/trained-models/apis/put-trained-models.asciidoc

@@ -458,6 +458,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -504,6 +527,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -544,6 +590,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 
@@ -607,6 +676,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 `text_embedding`:::
@@ -646,6 +738,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 `zero_shot_classification`:::
@@ -701,6 +816,29 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 =======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-truncate]
+
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet-with-special-tokens]
+=======
 ======
 =====
 ====

+ 26 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -41,6 +41,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpd
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenizationUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
@@ -435,6 +437,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 (p, c) -> BertTokenization.fromXContent(p, (boolean) c)
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                Tokenization.class,
+                MPNetTokenization.NAME,
+                (p, c) -> MPNetTokenization.fromXContent(p, (boolean) c)
+            )
+        );
 
         namedXContent.add(
             new NamedXContentRegistry.Entry(
@@ -443,6 +452,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 (p, c) -> BertTokenizationUpdate.fromXContent(p)
             )
         );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                TokenizationUpdate.class,
+                MPNetTokenizationUpdate.NAME,
+                (p, c) -> MPNetTokenizationUpdate.fromXContent(p)
+            )
+        );
 
         return namedXContent;
     }
@@ -591,6 +607,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(Tokenization.class, BertTokenization.NAME.getPreferredName(), BertTokenization::new)
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(Tokenization.class, MPNetTokenization.NAME.getPreferredName(), MPNetTokenization::new)
+        );
 
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
@@ -599,6 +618,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 BertTokenizationUpdate::new
             )
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TokenizationUpdate.class,
+                MPNetTokenizationUpdate.NAME.getPreferredName(),
+                MPNetTokenizationUpdate::new
+            )
+        );
 
         return namedWriteables;
     }

+ 4 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java

@@ -48,10 +48,6 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
 
     @Override
     public Tokenization apply(Tokenization originalConfig) {
-        if (isNoop()) {
-            return originalConfig;
-        }
-
         if (originalConfig instanceof BertTokenization == false) {
             throw ExceptionsHelper.badRequestException(
                 "Tokenization config of type [{}] can not be updated with a request of type [{}]",
@@ -60,6 +56,10 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
             );
         }
 
+        if (isNoop()) {
+            return originalConfig;
+        }
+
         return new BertTokenization(
             originalConfig.doLowerCase(),
             originalConfig.withSpecialTokens(),

+ 88 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class MPNetTokenization extends Tokenization {
+
+    public static final ParseField NAME = new ParseField("mpnet");
+
+    public static ConstructingObjectParser<MPNetTokenization, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<MPNetTokenization, Void> parser = new ConstructingObjectParser<>(
+            "mpnet_tokenization",
+            ignoreUnknownFields,
+            a -> new MPNetTokenization(
+                (Boolean) a[0],
+                (Boolean) a[1],
+                (Integer) a[2],
+                a[3] == null ? null : Truncate.fromString((String) a[3])
+            )
+        );
+        Tokenization.declareCommonFields(parser);
+        return parser;
+    }
+
+    private static final ConstructingObjectParser<MPNetTokenization, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<MPNetTokenization, Void> STRICT_PARSER = createParser(false);
+
+    public static MPNetTokenization fromXContent(XContentParser parser, boolean lenient) {
+        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    public MPNetTokenization(
+        @Nullable Boolean doLowerCase,
+        @Nullable Boolean withSpecialTokens,
+        @Nullable Integer maxSequenceLength,
+        @Nullable Truncate truncate
+    ) {
+        super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate);
+    }
+
+    public MPNetTokenization(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        return super.equals(o);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+}

+ 111 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class MPNetTokenizationUpdate implements TokenizationUpdate {
+
+    public static final ParseField NAME = MPNetTokenization.NAME;
+
+    public static ConstructingObjectParser<MPNetTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
+        "mpnet_tokenization_update",
+        a -> new MPNetTokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]))
+    );
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
+    }
+
+    public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Tokenization.Truncate truncate;
+
+    public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate) {
+        this.truncate = truncate;
+    }
+
+    public MPNetTokenizationUpdate(StreamInput in) throws IOException {
+        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
+    }
+
+    @Override
+    public Tokenization apply(Tokenization originalConfig) {
+        if (originalConfig instanceof MPNetTokenization == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Tokenization config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName()
+            );
+        }
+
+        if (isNoop()) {
+            return originalConfig;
+        }
+
+        return new MPNetTokenization(
+            originalConfig.doLowerCase(),
+            originalConfig.withSpecialTokens(),
+            originalConfig.maxSequenceLength(),
+            this.truncate
+        );
+    }
+
+    @Override
+    public boolean isNoop() {
+        return truncate == null;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return MPNetTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(truncate);
+    }
+
+    @Override
+    public String getName() {
+        return MPNetTokenization.NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
+        return truncate == that.truncate;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(truncate);
+    }
+}

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -50,7 +50,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
     public static FillMaskConfig createRandom() {
         return new FillMaskConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomInt(),
             randomBoolean() ? null : randomAlphaOfLength(5)
         );

+ 32 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigTestScaffolding.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+public final class InferenceConfigTestScaffolding {
+
+    static Tokenization cloneWithNewTruncation(Tokenization tokenization, Tokenization.Truncate truncate) {
+        return tokenization instanceof MPNetTokenization
+            ? new MPNetTokenization(
+                tokenization.doLowerCase(),
+                tokenization.withSpecialTokens(),
+                tokenization.maxSequenceLength(),
+                truncate
+            )
+            : new BertTokenization(
+                tokenization.doLowerCase(),
+                tokenization.withSpecialTokens(),
+                tokenization.maxSequenceLength(),
+                truncate
+            );
+    }
+
+    static TokenizationUpdate createTokenizationUpdate(Tokenization tokenization, Tokenization.Truncate truncate) {
+        return tokenization instanceof MPNetTokenization ? new MPNetTokenizationUpdate(truncate) : new BertTokenizationUpdate(truncate);
+    }
+
+}

+ 55 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationTests.java

@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+
+public class MPNetTokenizationTests extends AbstractBWCSerializationTestCase<MPNetTokenization> {
+
+    private boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected MPNetTokenization doParseInstance(XContentParser parser) throws IOException {
+        return MPNetTokenization.createParser(lenient).apply(parser, null);
+    }
+
+    @Override
+    protected Writeable.Reader<MPNetTokenization> instanceReader() {
+        return MPNetTokenization::new;
+    }
+
+    @Override
+    protected MPNetTokenization createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected MPNetTokenization mutateInstanceForVersion(MPNetTokenization instance, Version version) {
+        return instance;
+    }
+
+    public static MPNetTokenization createRandom() {
+        return new MPNetTokenization(
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomIntBetween(1, 1024),
+            randomBoolean() ? null : randomFrom(Tokenization.Truncate.values())
+        );
+    }
+}

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -50,7 +50,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
     public static NerConfig createRandom() {
         return new NerConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomAlphaOfLength(5)
         );

+ 8 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -65,12 +67,7 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new NerConfig(
                 originalConfig.getVocabularyConfig(),
@@ -78,7 +75,11 @@ public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerCo
                 originalConfig.getClassificationLabels(),
                 originalConfig.getResultsField()
             ),
-            equalTo(new NerConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate)).build().apply(originalConfig))
+            equalTo(
+                new NerConfigUpdate.Builder().setTokenizationUpdate(createTokenizationUpdate(originalConfig.getTokenization(), truncate))
+                    .build()
+                    .apply(originalConfig)
+            )
         );
     }
 

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java

@@ -50,7 +50,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
     public static PassThroughConfig createRandom() {
         return new PassThroughConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -63,18 +65,13 @@ public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCa
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new PassThroughConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
             equalTo(
-                new PassThroughConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new PassThroughConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java

@@ -69,7 +69,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
     public static TextClassificationConfig createRandom() {
         return new TextClassificationConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomList(2, 5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomBoolean() ? -1 : randomIntBetween(1, 10),
             randomBoolean() ? null : randomAlphaOfLength(6)

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java

@@ -23,6 +23,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -121,18 +123,13 @@ public class TextClassificationConfigUpdateTests extends AbstractBWCSerializatio
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new TextClassificationConfig.Builder(originalConfig).setTokenization(tokenization).build(),
             equalTo(
-                new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new TextClassificationConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java

@@ -50,7 +50,7 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEm
     public static TextEmbeddingConfig createRandom() {
         return new TextEmbeddingConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java

@@ -21,6 +21,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -63,18 +65,13 @@ public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTest
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new TextEmbeddingConfig(originalConfig.getVocabularyConfig(), tokenization, originalConfig.getResultsField()),
             equalTo(
-                new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new TextEmbeddingConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java

@@ -52,7 +52,7 @@ public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCa
         return new ZeroShotClassificationConfig(
             randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomFrom(BertTokenizationTests.createRandom(), MPNetTokenizationTests.createRandom()),
             randomAlphaOfLength(10),
             randomBoolean(),
             randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),

+ 6 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -22,6 +22,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.cloneWithNewTruncation;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigTestScaffolding.createTokenizationUpdate;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -137,12 +139,7 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
         );
 
         Tokenization.Truncate truncate = randomFrom(Tokenization.Truncate.values());
-        Tokenization tokenization = new BertTokenization(
-            originalConfig.getTokenization().doLowerCase(),
-            originalConfig.getTokenization().withSpecialTokens(),
-            originalConfig.getTokenization().maxSequenceLength(),
-            truncate
-        );
+        Tokenization tokenization = cloneWithNewTruncation(originalConfig.getTokenization(), truncate);
         assertThat(
             new ZeroShotClassificationConfig(
                 originalConfig.getClassificationLabels(),
@@ -154,9 +151,9 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
                 originalConfig.getResultsField()
             ),
             equalTo(
-                new ZeroShotClassificationConfigUpdate.Builder().setTokenizationUpdate(new BertTokenizationUpdate(truncate))
-                    .build()
-                    .apply(originalConfig)
+                new ZeroShotClassificationConfigUpdate.Builder().setTokenizationUpdate(
+                    createTokenizationUpdate(originalConfig.getTokenization(), truncate)
+                ).build().apply(originalConfig)
             )
         );
     }

+ 5 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -11,7 +11,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
@@ -26,16 +26,16 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     static final String ARG2 = "arg_2";
     static final String ARG3 = "arg_3";
 
-    private final BertTokenizer tokenizer;
+    private final NlpTokenizer tokenizer;
 
-    public BertRequestBuilder(BertTokenizer tokenizer) {
+    public BertRequestBuilder(NlpTokenizer tokenizer) {
         this.tokenizer = tokenizer;
     }
 
     @Override
     public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
         if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
         }
 
         TokenizationResult tokenization = tokenizer.buildTokenizationResult(
@@ -47,7 +47,7 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     @Override
     @Override
     public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
     public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
         if (tokenizer.getPadTokenId().isEmpty()) {
         if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN + " token in its vocabulary");
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
         }
         }
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
     }
     }

+ 66 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java

@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class MPNetRequestBuilder implements NlpTask.RequestBuilder {
+
+    static final String REQUEST_ID = "request_id";
+    static final String TOKENS = "tokens";
+    static final String ARG1 = "arg_1";
+
+    private final NlpTokenizer tokenizer;
+
+    public MPNetRequestBuilder(NlpTokenizer tokenizer) {
+        this.tokenizer = tokenizer;
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
+        if (tokenizer.getPadTokenId().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
+        }
+
+        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+            inputs.stream().map(s -> tokenizer.tokenize(s, truncate)).collect(Collectors.toList())
+        );
+        return buildRequest(tokenization, requestId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
+        if (tokenizer.getPadTokenId().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
+        }
+        return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
+    }
+
+    static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(REQUEST_ID, requestId);
+
+        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
+        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        return BytesReference.bytes(builder);
+    }
+
+}
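
A minimal usage sketch of this builder, mirroring the MPNetRequestBuilderTests added below (the vocabulary list and the wrapper class name are illustrative only):

```
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;

import java.io.IOException;
import java.util.List;

public class MPNetRequestExample {
    public static NlpTask.Request build(List<String> vocab) throws IOException {
        // MPNetTokenization args: doLowerCase, withSpecialTokens, maxSequenceLength, truncate
        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(vocab, new MPNetTokenization(null, null, 512, null)).build();
        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
        // The produced payload carries "request_id", the padded "tokens" ids and "arg_1" only;
        // BertRequestBuilder additionally writes arg_2 and arg_3.
        return requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
    }
}
```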

+ 171 - 77
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -20,6 +20,8 @@ import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.function.Function;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;

 /**
  * Performs basic tokenization and normalization of input text
@@ -41,7 +43,7 @@ public class BertTokenizer implements NlpTokenizer {

     public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;

-    private final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
+    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

     private final WordPieceTokenizer wordPieceTokenizer;
     private final List<String> originalVocab;
@@ -50,10 +52,17 @@ public class BertTokenizer implements NlpTokenizer {
     private final boolean doLowerCase;
     private final boolean doTokenizeCjKChars;
     private final boolean doStripAccents;
-    private final boolean withSpecialTokens;
+    protected final boolean withSpecialTokens;
     private final Set<String> neverSplit;
     private final int maxSequenceLength;
     private final NlpTask.RequestBuilder requestBuilder;
+    private final String sepToken;
+    protected final int sepTokenId;
+    private final String clsToken;
+    private final int clsTokenId;
+    private final String padToken;
+    private final String maskToken;
+    private final String unknownToken;

     protected BertTokenizer(
         List<String> originalVocab,
@@ -63,37 +72,97 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
     ) {
-        wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
+        this(
+            originalVocab,
+            vocab,
+            doLowerCase,
+            doTokenizeCjKChars,
+            doStripAccents,
+            withSpecialTokens,
+            maxSequenceLength,
+            requestBuilderFactory,
+            Sets.union(neverSplit, NEVER_SPLIT),
+            SEPARATOR_TOKEN,
+            CLASS_TOKEN,
+            PAD_TOKEN,
+            MASK_TOKEN,
+            UNKNOWN_TOKEN
+        );
+    }
+
+    protected BertTokenizer(
+        List<String> originalVocab,
+        SortedMap<String, Integer> vocab,
+        boolean doLowerCase,
+        boolean doTokenizeCjKChars,
+        boolean doStripAccents,
+        boolean withSpecialTokens,
+        int maxSequenceLength,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Set<String> neverSplit,
+        String sepToken,
+        String clsToken,
+        String padToken,
+        String maskToken,
+        String unknownToken
+    ) {
+        wordPieceTokenizer = new WordPieceTokenizer(vocab, unknownToken, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
         this.originalVocab = originalVocab;
         this.vocab = vocab;
         this.doLowerCase = doLowerCase;
         this.doTokenizeCjKChars = doTokenizeCjKChars;
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
-        this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
+        this.neverSplit = neverSplit;
         this.maxSequenceLength = maxSequenceLength;
         this.requestBuilder = requestBuilderFactory.apply(this);
-        if (vocab.containsKey(UNKNOWN_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN);
+        if (vocab.containsKey(unknownToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", unknownToken);
         }
-        if (vocab.containsKey(PAD_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN);
+        if (vocab.containsKey(padToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", padToken);
         }

         if (withSpecialTokens) {
-            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
+            Set<String> missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
             if (missingSpecialTokens.isEmpty() == false) {
                 throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens);
             }
+            this.sepTokenId = vocab.get(sepToken);
+            this.clsTokenId = vocab.get(clsToken);
+        } else {
+            this.sepTokenId = -1;
+            this.clsTokenId = -1;
         }
+        this.sepToken = sepToken;
+        this.clsToken = clsToken;
+        this.padToken = padToken;
+        this.maskToken = maskToken;
+        this.unknownToken = unknownToken;
+    }
+
+    public String getSepToken() {
+        return sepToken;
+    }
+
+    public String getClsToken() {
+        return clsToken;
+    }
+
+    public String getPadToken() {
+        return padToken;
+    }
+
+    public String getUnknownToken() {
+        return unknownToken;
     }

     @Override
     public OptionalInt getPadTokenId() {
-        Integer pad = vocab.get(PAD_TOKEN);
+        Integer pad = vocab.get(this.padToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
@@ -103,7 +172,7 @@ public class BertTokenizer implements NlpTokenizer {

     @Override
     public OptionalInt getMaskTokenId() {
-        Integer pad = vocab.get(MASK_TOKEN);
+        Integer pad = vocab.get(this.maskToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
@@ -113,7 +182,7 @@ public class BertTokenizer implements NlpTokenizer {

     @Override
     public String getMaskToken() {
-        return MASK_TOKEN;
+        return maskToken;
     }

     @Override
@@ -150,6 +219,7 @@ public class BertTokenizer implements NlpTokenizer {
                 case SECOND:
                     isTruncated = true;
                     wordPieceTokenIds = wordPieceTokenIds.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
+                    tokenPositionMap = tokenPositionMap.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -158,31 +228,16 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        if (withSpecialTokens) {
-            tokenIds[0] = vocab.get(CLASS_TOKEN);
-            tokenMap[0] = SPECIAL_TOKEN_POSITION;
-        }
-
-        int i = withSpecialTokens ? 1 : 0;
-        final int decrementHandler = withSpecialTokens ? 1 : 0;
-        for (var tokenId : wordPieceTokenIds) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMap.get(i - decrementHandler);
-            i++;
-        }
-
-        if (withSpecialTokens) {
-            tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-            tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
-
-        return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, tokenIds, tokenMap);
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIds, tokenPositionMap)
+            .addEndTokensIfNecessary();
+        return new TokenizationResult.Tokenization(
+            seq,
+            innerResult.tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
     }

     @Override
@@ -196,39 +251,47 @@ public class BertTokenizer implements NlpTokenizer {
         if (withSpecialTokens == false) {
             throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
         }
-        // [CLS] seq1 [SEP] seq2 [SEP]
-        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + 3;
+        int extraTokens = getNumExtraTokensForSeqPair();
+        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + extraTokens;

         boolean isTruncated = false;
         if (numTokens > maxSequenceLength) {
             switch (truncate) {
                 case FIRST:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the second sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq2.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq2.size());
+                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size()
+                    );
+                    tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size());
                     break;
                 case SECOND:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the first sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq1.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq1.size());
+                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size()
+                    );
+                    tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size());
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -237,38 +300,27 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        tokenIds[0] = vocab.get(CLASS_TOKEN);
-        tokenMap[0] = SPECIAL_TOKEN_POSITION;
-
-        int i = 1;
-        for (var tokenId : wordPieceTokenIdsSeq1) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq1.get(i - 1);
-            i++;
         }
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-        ++i;
-
-        int j = 0;
-        for (var tokenId : wordPieceTokenIdsSeq2) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq2.get(j);
-            i++;
-            j++;
-        }
-
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIdsSeq1, tokenPositionMapSeq1)
+            .addTokens(wordPieceTokenIdsSeq2, tokenPositionMapSeq2)
+            .addEndTokensIfNecessary();
         List<DelimitedToken> tokens = new ArrayList<>(innerResultSeq1.tokens);
         tokens.addAll(innerResultSeq2.tokens);
-        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(
+            seq1 + seq2,
+            tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
+    }
+
+    protected BertTokenizationBuilder bertTokenizationBuilder() {
+        return new BertTokenizationBuilder();
+    }
+
+    protected int getNumExtraTokensForSeqPair() {
+        return 3;
     }

     private InnerTokenization innerTokenize(String seq) {
@@ -280,7 +332,7 @@ public class BertTokenizer implements NlpTokenizer {
         for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); sourceIndex++) {
             String token = tokenSequences.get(sourceIndex).getToken();
             if (neverSplit.contains(token)) {
-                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)));
+                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(unknownToken)));
                 tokenPositionMap.add(sourceIndex);
             } else {
                 List<Integer> tokens = wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex));
@@ -319,6 +371,48 @@ public class BertTokenizer implements NlpTokenizer {
         return new Builder(vocab, tokenization);
     }

+    protected class BertTokenizationBuilder {
+        Stream.Builder<IntStream> tokenIds;
+        Stream.Builder<IntStream> tokenMap;
+        int numSeq;
+
+        BertTokenizationBuilder() {
+            tokenIds = Stream.builder();
+            tokenMap = Stream.builder();
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+        }
+
+        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (numSeq > 0 && withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            numSeq++;
+            return this;
+        }
+
+        BertTokenizationBuilder addEndTokensIfNecessary() {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+
+        int[] buildIds() {
+            return tokenIds.build().flatMapToInt(Function.identity()).toArray();
+        }
+
+        int[] buildMap() {
+            return tokenMap.build().flatMapToInt(Function.identity()).toArray();
+        }
+    }
+
     public static class Builder {

         protected final List<String> originalVocab;
@@ -329,7 +423,7 @@ public class BertTokenizer implements NlpTokenizer {
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
-        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
+        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;

         protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
@@ -382,7 +476,7 @@ public class BertTokenizer implements NlpTokenizer {
             return this;
         }

-        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
             this.requestBuilderFactory = requestBuilderFactory;
             return this;
         }
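
The new BertTokenizationBuilder replaces the hand-maintained array index bookkeeping with stream concatenation: each sequence contributes an IntStream of ids and source positions, and buildIds()/buildMap() flatten them once. A stripped-down sketch of that flattening pattern (the class and method names below are illustrative, not part of the commit):

```
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class TokenConcatSketch {
    // Illustrative only: compose [CLS] + seq1 + [SEP] + seq2 + [SEP] as streams, then flatten once.
    static int[] concat(int clsId, int sepId, List<Integer> seq1, List<Integer> seq2) {
        Stream.Builder<IntStream> parts = Stream.builder();
        parts.add(IntStream.of(clsId));
        parts.add(seq1.stream().mapToInt(Integer::valueOf));
        parts.add(IntStream.of(sepId));
        parts.add(seq2.stream().mapToInt(Integer::valueOf));
        parts.add(IntStream.of(sepId));
        return parts.build().flatMapToInt(Function.identity()).toArray();
    }
}
```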

+ 186 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java

@@ -0,0 +1,186 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.IntStream;
+
+/**
+ * Performs basic tokenization and normalization of input text
+ * then tokenizes with the WordPiece algorithm using the given
+ * vocabulary.
+ */
+public class MPNetTokenizer extends BertTokenizer {
+
+    public static final String UNKNOWN_TOKEN = "[UNK]";
+    public static final String SEPARATOR_TOKEN = "</s>";
+    public static final String PAD_TOKEN = "<pad>";
+    public static final String CLASS_TOKEN = "<s>";
+    public static final String MASK_TOKEN = "<mask>";
+    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
+
+    protected MPNetTokenizer(
+        List<String> originalVocab,
+        SortedMap<String, Integer> vocab,
+        boolean doLowerCase,
+        boolean doTokenizeCjKChars,
+        boolean doStripAccents,
+        boolean withSpecialTokens,
+        int maxSequenceLength,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Set<String> neverSplit
+    ) {
+        super(
+            originalVocab,
+            vocab,
+            doLowerCase,
+            doTokenizeCjKChars,
+            doStripAccents,
+            withSpecialTokens,
+            maxSequenceLength,
+            requestBuilderFactory,
+            Sets.union(neverSplit, NEVER_SPLIT),
+            SEPARATOR_TOKEN,
+            CLASS_TOKEN,
+            PAD_TOKEN,
+            MASK_TOKEN,
+            UNKNOWN_TOKEN
+        );
+    }
+
+    @Override
+    protected int getNumExtraTokensForSeqPair() {
+        return 4;
+    }
+
+    @Override
+    protected BertTokenizationBuilder bertTokenizationBuilder() {
+        return new MPNetTokenizationBuilder();
+    }
+
+    protected class MPNetTokenizationBuilder extends BertTokenizationBuilder {
+
+        @Override
+        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (numSeq > 0 && withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId, sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION, SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            numSeq++;
+            return this;
+        }
+
+    }
+
+    public static Builder mpBuilder(List<String> vocab, Tokenization tokenization) {
+        return new Builder(vocab, tokenization);
+    }
+
+    public static class Builder {
+
+        protected final List<String> originalVocab;
+        protected final SortedMap<String, Integer> vocab;
+        protected boolean doLowerCase = false;
+        protected boolean doTokenizeCjKChars = true;
+        protected boolean withSpecialTokens = true;
+        protected int maxSequenceLength;
+        protected Boolean doStripAccents = null;
+        protected Set<String> neverSplit;
+        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = MPNetRequestBuilder::new;
+
+        protected Builder(List<String> vocab, Tokenization tokenization) {
+            this.originalVocab = vocab;
+            this.vocab = buildSortedVocab(vocab);
+            this.doLowerCase = tokenization.doLowerCase();
+            this.withSpecialTokens = tokenization.withSpecialTokens();
+            this.maxSequenceLength = tokenization.maxSequenceLength();
+        }
+
+        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
+            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
+            for (int i = 0; i < vocab.size(); i++) {
+                sortedVocab.put(vocab.get(i), i);
+            }
+            return sortedVocab;
+        }
+
+        public Builder setDoLowerCase(boolean doLowerCase) {
+            this.doLowerCase = doLowerCase;
+            return this;
+        }
+
+        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
+            this.doTokenizeCjKChars = doTokenizeCjKChars;
+            return this;
+        }
+
+        public Builder setDoStripAccents(Boolean doStripAccents) {
+            this.doStripAccents = doStripAccents;
+            return this;
+        }
+
+        public Builder setNeverSplit(Set<String> neverSplit) {
+            this.neverSplit = neverSplit;
+            return this;
+        }
+
+        public Builder setMaxSequenceLength(int maxSequenceLength) {
+            this.maxSequenceLength = maxSequenceLength;
+            return this;
+        }
+
+        /**
+         * Include CLS and SEP tokens
+         * @param withSpecialTokens if true include CLS and SEP tokens
+         * @return this
+         */
+        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
+            this.withSpecialTokens = withSpecialTokens;
+            return this;
+        }
+
+        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+            this.requestBuilderFactory = requestBuilderFactory;
+            return this;
+        }
+
+        public MPNetTokenizer build() {
+            // if not set strip accents defaults to the value of doLowerCase
+            if (doStripAccents == null) {
+                doStripAccents = doLowerCase;
+            }
+
+            if (neverSplit == null) {
+                neverSplit = Collections.emptySet();
+            }
+
+            return new MPNetTokenizer(
+                originalVocab,
+                vocab,
+                doLowerCase,
+                doTokenizeCjKChars,
+                doStripAccents,
+                withSpecialTokens,
+                maxSequenceLength,
+                requestBuilderFactory,
+                neverSplit
+            );
+        }
+    }
+}
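
In practice the MPNet overrides change the sequence-pair layout: pairs are rendered as <s> seqA </s> </s> seqB </s>, which is why getNumExtraTokensForSeqPair() returns 4 rather than BERT's 3. A short sketch mirroring testMultiSeqTokenization below (the vocabulary list and wrapper class name are illustrative only):

```
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

import java.util.List;

class MPNetPairSketch {
    static TokenizationResult.Tokenization tokenizePair(List<String> vocab) {
        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(vocab, new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE))
            .setWithSpecialTokens(true)
            .build();
        // The resulting token ids start with <s>, separate the two inputs with a doubled </s>,
        // and end with a single </s>.
        return tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.NONE);
    }
}
```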

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -8,9 +8,11 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;

@@ -32,6 +34,8 @@ public interface NlpTokenizer {

     OptionalInt getPadTokenId();

+    String getPadToken();
+
     OptionalInt getMaskTokenId();

     String getMaskToken();
@@ -42,6 +46,9 @@ public interface NlpTokenizer {
         if (params instanceof BertTokenization) {
             return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
         }
+        if (params instanceof MPNetTokenization) {
+            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).setRequestBuilderFactory(MPNetRequestBuilder::new).build();
+        }
         throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
     }
 }
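
For callers, the effect of this branch is that the tokenizer (and, through the factory references, its request builder) is selected purely from the tokenization config type attached to the model. A sketch of the assembled dispatch, with an illustrative wrapper class and method name (the branch bodies mirror the factory shown above):

```
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;

class TokenizerSelectionSketch {
    // Sketch only: mirrors the dispatch added to NlpTokenizer above.
    static NlpTokenizer fromConfig(Vocabulary vocabulary, Tokenization params) {
        if (params instanceof MPNetTokenization) {
            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).setRequestBuilderFactory(MPNetRequestBuilder::new).build();
        }
        if (params instanceof BertTokenization) {
            return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
        }
        throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
    }
}
```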

+ 105 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilderTests.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizerTests.TEST_CASED_VOCAB;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasSize;
+
+public class MPNetRequestBuilderTests extends ESTestCase {
+
+    public void testBuildRequest() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
+
+        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(12, 0, 1, 3, 13), firstListItemFromMap("tokens", jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), firstListItemFromMap("arg_1", jsonDocAsMap));
+    }
+
+    @SuppressWarnings("unchecked")
+    private List<Integer> firstListItemFromMap(String name, Map<String, Object> jsonDocAsMap) {
+        return nthListItemFromMap(name, 0, jsonDocAsMap);
+    }
+
+    @SuppressWarnings("unchecked")
+    public static List<Integer> nthListItemFromMap(String name, int n, Map<String, Object> jsonDocAsMap) {
+        return ((List<List<Integer>>) jsonDocAsMap.get(name)).get(n);
+    }
+
+    public void testInputTooLarge() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 5, null)).build();
+        {
+            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            ElasticsearchStatusException e = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> requestBuilder.buildRequest(
+                    Collections.singletonList("Elasticsearch fun Elasticsearch fun Elasticsearch fun"),
+                    "request1",
+                    Tokenization.Truncate.NONE
+                )
+            );
+
+            assertThat(
+                e.getMessage(),
+                containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]")
+            );
+        }
+        {
+            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            // input will become 3 tokens + the Class and Separator token = 5 which is
+            // our max sequence length
+            requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testBatchWithPadding() throws IOException {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
+
+        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(
+            List.of("Elasticsearch", "my little red car", "Godzilla day"),
+            "request1",
+            Tokenization.Truncate.NONE
+        );
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
+
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(12, 0, 1, 13, 19, 19), nthListItemFromMap("tokens", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 19, 19), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(12, 4, 5, 6, 7, 13), nthListItemFromMap("tokens", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(12, 8, 9, 16, 13, 19), nthListItemFromMap("tokens", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 19), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
+    }
+}

+ 95 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.contains;
+
+public class MPNetTokenizerTests extends ESTestCase {
+
+    public static final List<String> TEST_CASED_VOCAB = List.of(
+        "Elastic",
+        "##search",
+        "is",
+        "fun",
+        "my",
+        "little",
+        "red",
+        "car",
+        "God",
+        "##zilla",
+        ".",
+        ",",
+        MPNetTokenizer.CLASS_TOKEN,
+        MPNetTokenizer.SEPARATOR_TOKEN,
+        MPNetTokenizer.MASK_TOKEN,
+        MPNetTokenizer.UNKNOWN_TOKEN,
+        "day",
+        "Pancake",
+        "with",
+        MPNetTokenizer.PAD_TOKEN
+    );
+
+    private List<String> tokenStrings(List<DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::getToken).collect(Collectors.toList());
+    }
+
+    public void testTokenize() {
+        BertTokenizer tokenizer = MPNetTokenizer.mpBuilder(
+            TEST_CASED_VOCAB,
+            new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
+        ).build();
+
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+        assertThat(tokenStrings(tokenization.getTokens()), contains("Elasticsearch", "fun"));
+        assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
+        assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+    }
+
+    public void testMultiSeqTokenization() {
+        MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(
+            TEST_CASED_VOCAB,
+            new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
+        ).setDoLowerCase(false).setWithSpecialTokens(true).build();
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            "Elasticsearch is fun",
+            "Godzilla my little red car",
+            Tokenization.Truncate.NONE
+        );
+
+        var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+        assertThat(
+            tokenStream,
+            contains(
+                MPNetTokenizer.CLASS_TOKEN,
+                "Elastic",
+                "##search",
+                "is",
+                "fun",
+                MPNetTokenizer.SEPARATOR_TOKEN,
+                MPNetTokenizer.SEPARATOR_TOKEN,
+                "God",
+                "##zilla",
+                "my",
+                "little",
+                "red",
+                "car",
+                MPNetTokenizer.SEPARATOR_TOKEN
+            )
+        );
+        assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.getTokenIds());
+    }
+
+}