Browse Source

[NLP] Support the different mask tokens used by NLP models for Fill Mask (#97453)

Add mask_token field to fill_mask of _ml/trained_models.

This change adds a mask_token field to the GET _ml/trained_models API, enabling users and Kibana to retrieve the particular mask token needed for each deployed model, as an enhancement to support kibana#159577.
Max Hniebergall 2 years ago
parent
commit
3a4113801c
16 changed files with 275 additions and 8 deletions
  1. 5 1
      docs/reference/ml/ml-shared.asciidoc
  2. 4 1
      docs/reference/ml/trained-models/apis/get-trained-models.asciidoc
  3. 7 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertJapaneseTokenization.java
  4. 7 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenization.java
  5. 31 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java
  6. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java
  7. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenization.java
  8. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java
  9. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/XLMRobertaTokenization.java
  10. 74 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java
  11. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  12. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java
  13. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java
  14. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java
  15. 2 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java
  16. 119 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

+ 5 - 1
docs/reference/ml/ml-shared.asciidoc

@@ -955,7 +955,7 @@ BERT-style tokenization is to be performed with the enclosed settings.
 end::inference-config-nlp-tokenization-bert[]
 
 tag::inference-config-nlp-tokenization-bert-ja[]
-experimental:[] BERT-style tokenization for Japanese text is to be performed 
+experimental:[] BERT-style tokenization for Japanese text is to be performed
 with the enclosed settings.
 end::inference-config-nlp-tokenization-bert-ja[]
 
@@ -1125,6 +1125,10 @@ The field that is added to incoming documents to contain the inference
 prediction. Defaults to `predicted_value`.
 end::inference-config-results-field[]
 
+tag::inference-config-mask-token[]
+The string/token which will be removed from incoming documents and replaced with the inference prediction(s). In a response, this field contains the mask token for the specified model/tokenizer. Each model and tokenizer has a predefined mask token which cannot be changed. Thus, it is recommended not to set this value in requests. However, if this field is present in a request, its value must match the predefined value for that model/tokenizer, otherwise the request will fail.
+end::inference-config-mask-token[]
+
 tag::inference-config-results-field-processor[]
 The field that is added to incoming documents to contain the inference
 prediction. Defaults to the `results_field` value of the {dfanalytics-job} that was

+ 4 - 1
docs/reference/ml/trained-models/apis/get-trained-models.asciidoc

@@ -166,7 +166,6 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]
 ======
-
 `fill_mask`::::
 (Optional, object)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
@@ -174,6 +173,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
 .Properties of fill_mask inference
 [%collapsible%open]
 ======
+`mask_token`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-mask-token]
+
 `tokenization`::::
 (Optional, object)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertJapaneseTokenization.java

@@ -20,6 +20,8 @@ public class BertJapaneseTokenization extends Tokenization {
 
     public static final ParseField NAME = new ParseField("bert_ja");
 
+    public static final String MASK_TOKEN = "[MASK]";
+
     public static ConstructingObjectParser<BertJapaneseTokenization, Void> createJpParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<BertJapaneseTokenization, Void> parser = new ConstructingObjectParser<>(
             "bert_japanese_tokenization",
@@ -61,6 +63,11 @@ public class BertJapaneseTokenization extends Tokenization {
         return builder;
     }
 
+    @Override
+    public String getMaskToken() {
+        return MASK_TOKEN;
+    }
+
     @Override
     public String getWriteableName() {
         return BertJapaneseTokenization.NAME.getPreferredName();

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenization.java

@@ -21,6 +21,8 @@ public class BertTokenization extends Tokenization {
 
     public static final ParseField NAME = new ParseField("bert");
 
+    public static final String MASK_TOKEN = "[MASK]";
+
     public static ConstructingObjectParser<BertTokenization, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<BertTokenization, Void> parser = new ConstructingObjectParser<>(
             "bert_tokenization",
@@ -67,6 +69,11 @@ public class BertTokenization extends Tokenization {
         return builder;
     }
 
+    @Override
+    public String getMaskToken() {
+        return MASK_TOKEN;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();

+ 31 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java

@@ -9,15 +9,18 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -26,6 +29,7 @@ import java.util.Optional;
 public class FillMaskConfig implements NlpConfig {
 
     public static final String NAME = "fill_mask";
+    public static final String MASK_TOKEN = "mask_token";
     public static final int DEFAULT_NUM_RESULTS = 5;
 
     public static FillMaskConfig fromXContentStrict(XContentParser parser) {
@@ -36,6 +40,7 @@ public class FillMaskConfig implements NlpConfig {
         return LENIENT_PARSER.apply(parser, null).build();
     }
 
+    private static final ParseField MASK_TOKEN_FIELD = new ParseField(MASK_TOKEN);
     private static final ObjectParser<FillMaskConfig.Builder, Void> STRICT_PARSER = createParser(false);
     private static final ObjectParser<FillMaskConfig.Builder, Void> LENIENT_PARSER = createParser(true);
 
@@ -57,6 +62,7 @@ public class FillMaskConfig implements NlpConfig {
         );
         parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
         parser.declareString(Builder::setResultsField, RESULTS_FIELD);
+        parser.declareString(Builder::setMaskToken, MASK_TOKEN_FIELD);
         return parser;
     }
 
@@ -101,6 +107,9 @@ public class FillMaskConfig implements NlpConfig {
         if (resultsField != null) {
             builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
         }
+        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
+            builder.field(MASK_TOKEN_FIELD.getPreferredName(), tokenization.getMaskToken());
+        }
         builder.endObject();
         return builder;
     }
@@ -182,8 +191,9 @@ public class FillMaskConfig implements NlpConfig {
     public static class Builder {
         private VocabularyConfig vocabularyConfig;
         private Tokenization tokenization;
-        private int numTopClasses;
+        private Integer numTopClasses;
         private String resultsField;
+        private String maskToken;
 
         Builder() {}
 
@@ -214,8 +224,27 @@ public class FillMaskConfig implements NlpConfig {
             return this;
         }
 
-        public FillMaskConfig build() {
+        public FillMaskConfig.Builder setMaskToken(String maskToken) {
+            this.maskToken = maskToken;
+            return this;
+        }
+
+        public FillMaskConfig build() throws IllegalArgumentException {
+            if (tokenization == null) {
+                tokenization = Tokenization.createDefault();
+            }
+            validateMaskToken(tokenization.getMaskToken());
             return new FillMaskConfig(vocabularyConfig, tokenization, numTopClasses, resultsField);
         }
+
+        private void validateMaskToken(String tokenizationMaskToken) throws IllegalArgumentException {
+            if (maskToken != null) {
+                if (maskToken.equals(tokenizationMaskToken) == false) {
+                    throw new IllegalArgumentException(
+                        Strings.format("Mask token requested was [%s] but must be [%s] for this model", maskToken, tokenizationMaskToken)
+                    );
+                }
+            }
+        }
     }
 }

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenization.java

@@ -20,6 +20,7 @@ import java.io.IOException;
 public class MPNetTokenization extends Tokenization {
 
     public static final ParseField NAME = new ParseField("mpnet");
+    public static final String MASK_TOKEN = "<mask>";
 
     public static ConstructingObjectParser<MPNetTokenization, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<MPNetTokenization, Void> parser = new ConstructingObjectParser<>(
@@ -67,6 +68,11 @@ public class MPNetTokenization extends Tokenization {
         return builder;
     }
 
+    @Override
+    public String getMaskToken() {
+        return MASK_TOKEN;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenization.java

@@ -23,6 +23,7 @@ import java.util.Optional;
 
 public class RobertaTokenization extends Tokenization {
     public static final String NAME = "roberta";
+    public static final String MASK_TOKEN = "<mask>";
     private static final boolean DEFAULT_ADD_PREFIX_SPACE = false;
 
     private static final ParseField ADD_PREFIX_SPACE = new ParseField("add_prefix_space");
@@ -99,6 +100,11 @@ public class RobertaTokenization extends Tokenization {
         out.writeBoolean(addPrefixSpace);
     }
 
+    @Override
+    public String getMaskToken() {
+        return MASK_TOKEN;
+    }
+
     @Override
     XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(ADD_PREFIX_SPACE.getPreferredName(), addPrefixSpace);

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java

@@ -143,6 +143,8 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
         }
     }
 
+    public abstract String getMaskToken();
+
     abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;
 
     @Override

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/XLMRobertaTokenization.java

@@ -21,6 +21,7 @@ import java.io.IOException;
 
 public class XLMRobertaTokenization extends Tokenization {
     public static final String NAME = "xlm_roberta";
+    public static final String MASK_TOKEN = "<mask>";
 
     public static ConstructingObjectParser<XLMRobertaTokenization, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<XLMRobertaTokenization, Void> parser = new ConstructingObjectParser<>(
@@ -81,6 +82,11 @@ public class XLMRobertaTokenization extends Tokenization {
         super.writeTo(out);
     }
 
+    @Override
+    public String getMaskToken() {
+        return MASK_TOKEN;
+    }
+
     @Override
     XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         return builder;

+ 74 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -75,4 +75,78 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
             randomBoolean() ? null : randomAlphaOfLength(5)
         );
     }
+
+    public void testCreateBuilder() {
+
+        VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();
+
+        Tokenization tokenization = randomBoolean()
+            ? null
+            : randomFrom(
+                BertTokenizationTests.createRandom(),
+                MPNetTokenizationTests.createRandom(),
+                RobertaTokenizationTests.createRandom()
+            );
+
+        Integer numTopClasses = randomBoolean() ? null : randomInt();
+
+        String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);
+
+        new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
+            .setTokenization(tokenization)
+            .setNumTopClasses(numTopClasses)
+            .setResultsField(resultsField)
+            .setMaskToken(tokenization == null ? null : tokenization.getMaskToken())
+            .build();
+    }
+
+    public void testCreateBuilderWithException() throws Exception {
+
+        VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();
+
+        Tokenization tokenization = randomBoolean()
+            ? null
+            : randomFrom(
+                BertTokenizationTests.createRandom(),
+                MPNetTokenizationTests.createRandom(),
+                RobertaTokenizationTests.createRandom()
+            );
+
+        Integer numTopClasses = randomBoolean() ? null : randomInt();
+
+        String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> {
+            FillMaskConfig fmc = new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
+                .setTokenization(tokenization)
+                .setNumTopClasses(numTopClasses)
+                .setResultsField(resultsField)
+                .setMaskToken("not a real mask token")
+                .build();
+        });
+
+    }
+
+    public void testCreateBuilderWithNullMaskToken() {
+
+        VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();
+
+        Tokenization tokenization = randomBoolean()
+            ? null
+            : randomFrom(
+                BertTokenizationTests.createRandom(),
+                MPNetTokenizationTests.createRandom(),
+                RobertaTokenizationTests.createRandom()
+            );
+
+        Integer numTopClasses = randomBoolean() ? null : randomInt();
+
+        String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);
+
+        new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
+            .setTokenization(tokenization)
+            .setNumTopClasses(numTopClasses)
+            .setResultsField(resultsField)
+            .build();
+    }
+
 }

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
@@ -39,7 +40,7 @@ public class BertTokenizer extends NlpTokenizer {
     public static final String SEPARATOR_TOKEN = "[SEP]";
     public static final String PAD_TOKEN = "[PAD]";
     public static final String CLASS_TOKEN = "[CLS]";
-    public static final String MASK_TOKEN = "[MASK]";
+    public static final String MASK_TOKEN = BertTokenization.MASK_TOKEN;
 
     private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 
 import java.util.Collections;
@@ -26,7 +27,7 @@ public class MPNetTokenizer extends BertTokenizer {
     public static final String SEPARATOR_TOKEN = "</s>";
     public static final String PAD_TOKEN = "<pad>";
     public static final String CLASS_TOKEN = "<s>";
-    public static final String MASK_TOKEN = "<mask>";
+    public static final String MASK_TOKEN = MPNetTokenization.MASK_TOKEN;
     private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 
     protected MPNetTokenizer(

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java

@@ -32,7 +32,7 @@ public class RobertaTokenizer extends NlpTokenizer {
     public static final String SEPARATOR_TOKEN = "</s>";
     public static final String PAD_TOKEN = "<pad>";
     public static final String CLASS_TOKEN = "<s>";
-    public static final String MASK_TOKEN = "<mask>";
+    public static final String MASK_TOKEN = RobertaTokenization.MASK_TOKEN;
 
     private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java

@@ -34,7 +34,7 @@ public class XLMRobertaTokenizer extends NlpTokenizer {
     public static final String SEPARATOR_TOKEN = "</s>";
     public static final String PAD_TOKEN = "<pad>";
     public static final String CLASS_TOKEN = "<s>";
-    public static final String MASK_TOKEN = "<mask>";
+    public static final String MASK_TOKEN = XLMRobertaTokenization.MASK_TOKEN;
 
     private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -135,6 +136,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
             Map<String, String> params = new HashMap<>(channel.request().params());
             defaultToXContentParamValues.forEach((k, v) -> params.computeIfAbsent(k, defaultToXContentParamValues::get));
             includes.forEach(include -> params.put(include, "true"));
+            params.put(ToXContentParams.FOR_INTERNAL_STORAGE, "false");
             response.toXContent(builder, new ToXContent.MapParams(params));
             return new RestResponse(getStatus(response), builder);
         }

+ 119 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -78,6 +78,125 @@ setup:
             "total_parts": 1
           }
 ---
+"Test getting and putting Fill Mask with two mask tokens, as well as exceptions caused by requests with the wrong token":
+  - do:
+      ml.put_trained_model:
+        model_id: "bert_fill_mask_model"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "bert": {
+                    "with_special_tokens": false
+                  }
+                }
+              }
+            }
+          }
+
+  - do:
+      ml.put_trained_model:
+        model_id: "roberta_fill_mask_model"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "roberta": {
+                    "with_special_tokens": false
+                  }
+                }
+              }
+            }
+          }
+  - do:
+      ml.put_trained_model:
+        model_id: "with_correct_mask_token"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "bert": {
+                    "with_special_tokens": false
+                  }
+                },
+                "mask_token": "[MASK]"
+              }
+            }
+          }
+  - do:
+      ml.put_trained_model:
+        model_id: "with_other_correct_mask_token"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "roberta": {
+                    "with_special_tokens": false
+                  }
+                },
+                "mask_token": "<mask>"
+              }
+            }
+          }
+  - do:
+      ml.get_trained_models:
+        model_id: "bert_fill_mask_model"
+  - match: {trained_model_configs.0.inference_config.fill_mask.mask_token: "[MASK]"}
+  - do:
+      ml.get_trained_models:
+        model_id: "roberta_fill_mask_model"
+  - match: {trained_model_configs.0.inference_config.fill_mask.mask_token: "<mask>"}
+  - do:
+      catch: /IllegalArgumentException. Mask token requested was \[<mask>\] but must be \[\[MASK\]\] for this model/
+      ml.put_trained_model:
+        model_id: "incorrect_mask_token"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "bert": {
+                    "with_special_tokens": false
+                  }
+                },
+                "mask_token": "<mask>"
+              }
+            }
+          }
+  - do:
+      catch: /IllegalArgumentException. Mask token requested was \[\[MASK\]\] but must be \[<mask>\] for this model/
+      ml.put_trained_model:
+        model_id: "incorrect_mask_token"
+        body: >
+          {
+            "description": "simple model for testing",
+            "model_type": "pytorch",
+            "inference_config": {
+              "fill_mask": {
+                "tokenization":{
+                  "roberta": {
+                    "with_special_tokens": false
+                  }
+                },
+                "mask_token": "[MASK]"
+              }
+            }
+          }
+---
 "Test start deployment fails with missing model definition":
 
   - do: