
[ML] add support for distilbert pytorch models (#76679)

This commit adds support for DistilBERT PyTorch models.

While the tokenization itself is exactly the same as BERT's, the parameters sent to the model are different.

DistilBERT does not require the segment mask or positional IDs to be sent, only the input mask and token IDs.
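
To make the difference concrete, here is a minimal sketch of the two payloads, built with the same XContent helpers the request builders use. The `tokens` and `arg_1` field names match the constants in the request-builder diffs below; treating `arg_2` as the segment mask and `arg_3` as the positional IDs is an assumption based on the description above.

```java
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;

// Hypothetical inputs for a three-token sequence.
String requestId = "request-1";
int[] tokenIds = {101, 7592, 102};
int[] inputMask = {1, 1, 1};
int[] segmentMask = {0, 0, 0};
int[] positionalIds = {0, 1, 2};

// BERT: token IDs plus input mask, segment mask, and positional IDs.
XContentBuilder bert = XContentFactory.jsonBuilder().startObject()
    .field("request_id", requestId)
    .array("tokens", tokenIds)
    .array("arg_1", inputMask)       // input mask
    .array("arg_2", segmentMask)     // assumed: segment mask
    .array("arg_3", positionalIds)   // assumed: positional IDs
    .endObject();

// DistilBERT: only the token IDs and the input mask.
XContentBuilder distilBert = XContentFactory.jsonBuilder().startObject()
    .field("request_id", requestId)
    .array("tokens", tokenIds)
    .array("arg_1", inputMask)
    .endObject();
```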

But since the effective output of the tokenization sent to the model is different, I opted to treat it as a distinct
tokenizer, inheriting from our BERT implementation.

The API now looks like:
For BERT models:
```js
"inference_config": {
  "ner": {
    "vocabulary": {/*...*/},
    "tokenization": {
      "bert": {/*...*/}
    }
  }
}
```
For DistilBERT models:
```js
"inference_config": {
  "ner": {
    "vocabulary": {/*...*/},
    "tokenization": {
      "distil_bert": {/*...*/}
    }
  }
}
```
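
The same configs can be built in Java; a minimal sketch mirroring the TestFeatureResetIT change below, with a hypothetical vocabulary index and document id:

```java
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;

// An NER config backed by the DistilBERT tokenizer. Null tokenization
// arguments fall back to the defaults (do_lower_case=false,
// with_special_tokens=true, max_sequence_length=512); a null label list
// means no custom classification labels.
NerConfig nerConfig = new NerConfig(
    new VocabularyConfig("my-vocab-index", "my_model_vocab"),  // hypothetical location
    new DistilBertTokenization(null, null, null),
    null
);
```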
Benjamin Trent, 4 years ago
Commit: f913aaef5b
36 changed files with 820 additions and 269 deletions
  1. +35 -0   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  2. +16 -13  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertPassThroughConfig.java
  3. +78 -0   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenization.java
  4. +82 -0   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenization.java
  5. +16 -13  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java
  6. +16 -13  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java
  7. +2 -2    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java
  8. +16 -13  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfig.java
  9. +23 -19  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java
  10. +36 -0   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java
  11. +6 -4    x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertPassThroughConfigTests.java
  12. +9 -9    x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java
  13. +54 -0   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenizationTests.java
  14. +6 -3    x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java
  15. +5 -3    x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java
  16. +5 -3    x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfigTests.java
  17. +2 -2    x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  18. +2 -2    x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
  19. +2 -2    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  20. +3 -10   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  21. +52 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java
  22. +8 -6    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  23. +9 -9    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  24. +11 -9   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  25. +8 -7    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  26. +6 -5    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessor.java
  27. +6 -6    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
  28. +72 -72  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  29. +39 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  30. +63 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  31. +10 -6   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java
  32. +64 -0   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java
  33. +4 -4    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  34. +20 -10  x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  35. +4 -1    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessorTests.java
  36. +30 -23  x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

+ 35 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -27,7 +27,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResul
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
@@ -46,6 +48,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisC
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
@@ -189,6 +192,22 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             LangIdentNeuralNetwork.NAME,
             LangIdentNeuralNetwork::fromXContentLenient));
 
+        // Tokenization
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                Tokenization.class,
+                BertTokenization.NAME,
+                (p, c) -> BertTokenization.fromXContent(p, (boolean) c)
+            )
+        );
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                Tokenization.class,
+                DistilBertTokenization.NAME,
+                (p, c) -> DistilBertTokenization.fromXContent(p, (boolean) c)
+            )
+        );
+
         return namedXContent;
     }
 
@@ -280,6 +299,22 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModelLocation.class,
             IndexLocation.INDEX.getPreferredName(), IndexLocation::new));
 
+        // Tokenization
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                Tokenization.class,
+                BertTokenization.NAME.getPreferredName(),
+                BertTokenization::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                Tokenization.class,
+                DistilBertTokenization.NAME.getPreferredName(),
+                DistilBertTokenization::new
+            )
+        );
+
         return namedWriteables;
     }
 }

+ 16 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertPassThroughConfig.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -36,31 +37,33 @@ public class BertPassThroughConfig implements NlpConfig {
 
     private static ConstructingObjectParser<BertPassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<BertPassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new BertPassThroughConfig((VocabularyConfig) a[0], (TokenizationParams) a[1]));
+            a -> new BertPassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
         parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
-        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), TokenizationParams.createParser(ignoreUnknownFields),
-            TOKENIZATION_PARAMS);
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+            TOKENIZATION
+        );
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
-    private final TokenizationParams tokenizationParams;
+    private final Tokenization tokenization;
 
-    public BertPassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable TokenizationParams tokenizationParams) {
+    public BertPassThroughConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
         this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
-        this.tokenizationParams = tokenizationParams == null ? TokenizationParams.createDefault() : tokenizationParams;
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
     }
 
     public BertPassThroughConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
-        tokenizationParams = new TokenizationParams(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
-        builder.field(TOKENIZATION_PARAMS.getPreferredName(), tokenizationParams);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
         builder.endObject();
         return builder;
     }
@@ -73,7 +76,7 @@ public class BertPassThroughConfig implements NlpConfig {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
-        tokenizationParams.writeTo(out);
+        out.writeNamedWriteable(tokenization);
     }
 
     @Override
@@ -103,12 +106,12 @@ public class BertPassThroughConfig implements NlpConfig {
 
         BertPassThroughConfig that = (BertPassThroughConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenizationParams, that.tokenizationParams);
+            && Objects.equals(tokenization, that.tokenization);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenizationParams);
+        return Objects.hash(vocabularyConfig, tokenization);
     }
 
     @Override
@@ -117,7 +120,7 @@ public class BertPassThroughConfig implements NlpConfig {
     }
 
     @Override
-    public TokenizationParams getTokenizationParams() {
-        return tokenizationParams;
+    public Tokenization getTokenization() {
+        return tokenization;
     }
 }

+ 78 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenization.java

@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+
+import java.io.IOException;
+
+public class BertTokenization extends Tokenization {
+
+    public static final ParseField NAME = new ParseField("bert");
+
+    public static ConstructingObjectParser<BertTokenization, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<BertTokenization, Void> parser = new ConstructingObjectParser<>(
+            "bert_tokenization",
+            ignoreUnknownFields,
+            a -> new BertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2])
+        );
+        Tokenization.declareCommonFields(parser);
+        return parser;
+    }
+
+    private static final ConstructingObjectParser<BertTokenization, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<BertTokenization, Void> STRICT_PARSER = createParser(false);
+
+    public static BertTokenization fromXContent(XContentParser parser, boolean lenient) {
+        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    public BertTokenization(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) {
+        super(doLowerCase, withSpecialTokens, maxSequenceLength);
+    }
+
+    public BertTokenization(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        return super.equals(o);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+}

+ 82 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenization.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+
+import java.io.IOException;
+
+public class DistilBertTokenization extends Tokenization {
+
+    public static final ParseField NAME = new ParseField("distil_bert");
+
+    public static ConstructingObjectParser<DistilBertTokenization, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<DistilBertTokenization, Void> parser = new ConstructingObjectParser<>(
+            "distil_bert_tokenization",
+            ignoreUnknownFields,
+            a -> new DistilBertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2])
+        );
+        Tokenization.declareCommonFields(parser);
+        return parser;
+    }
+
+    private static final ConstructingObjectParser<DistilBertTokenization, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<DistilBertTokenization, Void> STRICT_PARSER = createParser(false);
+
+    public static DistilBertTokenization fromXContent(XContentParser parser, boolean lenient) {
+        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    public DistilBertTokenization(
+        @Nullable Boolean doLowerCase,
+        @Nullable Boolean withSpecialTokens,
+        @Nullable Integer maxSequenceLength
+    ) {
+        super(doLowerCase, withSpecialTokens, maxSequenceLength);
+    }
+
+    public DistilBertTokenization(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        return super.equals(o);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+}

+ 16 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -36,31 +37,33 @@ public class FillMaskConfig implements NlpConfig {
 
     private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new FillMaskConfig((VocabularyConfig) a[0], (TokenizationParams) a[1]));
+            a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
         parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
-        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), TokenizationParams.createParser(ignoreUnknownFields),
-            TOKENIZATION_PARAMS);
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+                TOKENIZATION
+        );
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
-    private final TokenizationParams tokenizationParams;
+    private final Tokenization tokenization;
 
-    public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable TokenizationParams tokenizationParams) {
+    public FillMaskConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
         this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
-        this.tokenizationParams = tokenizationParams == null ? TokenizationParams.createDefault() : tokenizationParams;
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
     }
 
     public FillMaskConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
-        tokenizationParams = new TokenizationParams(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
-        builder.field(TOKENIZATION_PARAMS.getPreferredName(), tokenizationParams);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
         builder.endObject();
         return builder;
     }
@@ -73,7 +76,7 @@ public class FillMaskConfig implements NlpConfig {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
-        tokenizationParams.writeTo(out);
+        out.writeNamedWriteable(tokenization);
     }
 
     @Override
@@ -98,12 +101,12 @@ public class FillMaskConfig implements NlpConfig {
 
         FillMaskConfig that = (FillMaskConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenizationParams, that.tokenizationParams);
+            && Objects.equals(tokenization, that.tokenization);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenizationParams);
+        return Objects.hash(vocabularyConfig, tokenization);
     }
 
     @Override
@@ -112,8 +115,8 @@ public class FillMaskConfig implements NlpConfig {
     }
 
     @Override
-    public TokenizationParams getTokenizationParams() {
-        return tokenizationParams;
+    public Tokenization getTokenization() {
+        return tokenization;
     }
 
     @Override

+ 16 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -39,36 +40,38 @@ public class NerConfig implements NlpConfig {
     @SuppressWarnings({ "unchecked"})
     private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new NerConfig((VocabularyConfig) a[0], (TokenizationParams) a[1], (List<String>) a[2]));
+            a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
         parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
-        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), TokenizationParams.createParser(ignoreUnknownFields),
-            TOKENIZATION_PARAMS);
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+                TOKENIZATION
+        );
         parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
-    private final TokenizationParams tokenizationParams;
+    private final Tokenization tokenization;
     private final List<String> classificationLabels;
 
     public NerConfig(VocabularyConfig vocabularyConfig,
-                     @Nullable TokenizationParams tokenizationParams,
+                     @Nullable Tokenization tokenization,
                      @Nullable List<String> classificationLabels) {
         this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
-        this.tokenizationParams = tokenizationParams == null ? TokenizationParams.createDefault() : tokenizationParams;
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
         this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
     }
 
     public NerConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
-        tokenizationParams = new TokenizationParams(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
         classificationLabels = in.readStringList();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
-        tokenizationParams.writeTo(out);
+        out.writeNamedWriteable(tokenization);
         out.writeStringCollection(classificationLabels);
     }
 
@@ -76,7 +79,7 @@ public class NerConfig implements NlpConfig {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
-        builder.field(TOKENIZATION_PARAMS.getPreferredName(), tokenizationParams);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
         if (classificationLabels.isEmpty() == false) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
@@ -111,13 +114,13 @@ public class NerConfig implements NlpConfig {
 
         NerConfig that = (NerConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenizationParams, that.tokenizationParams)
+            && Objects.equals(tokenization, that.tokenization)
             && Objects.equals(classificationLabels, that.classificationLabels);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenizationParams, classificationLabels);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels);
     }
 
     @Override
@@ -126,8 +129,8 @@ public class NerConfig implements NlpConfig {
     }
 
     @Override
-    public TokenizationParams getTokenizationParams() {
-        return tokenizationParams;
+    public Tokenization getTokenization() {
+        return tokenization;
     }
 
     public List<String> getClassificationLabels() {

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java

@@ -12,7 +12,7 @@ import org.elasticsearch.common.xcontent.ParseField;
 public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig {
 
     ParseField VOCABULARY = new ParseField("vocabulary");
-    ParseField TOKENIZATION_PARAMS = new ParseField("tokenization_params");
+    ParseField TOKENIZATION = new ParseField("tokenization");
     ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
 
     /**
@@ -23,5 +23,5 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse
     /**
      * @return the model tokenization parameters
      */
-    TokenizationParams getTokenizationParams();
+    Tokenization getTokenization();
 }

+ 16 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfig.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -39,35 +40,37 @@ public class SentimentAnalysisConfig implements NlpConfig {
     @SuppressWarnings({ "unchecked"})
     private static ConstructingObjectParser<SentimentAnalysisConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<SentimentAnalysisConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new SentimentAnalysisConfig((VocabularyConfig) a[0], (TokenizationParams) a[1], (List<String>) a[2]));
+            a -> new SentimentAnalysisConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
         parser.declareObject(ConstructingObjectParser.constructorArg(), VocabularyConfig.createParser(ignoreUnknownFields), VOCABULARY);
-        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), TokenizationParams.createParser(ignoreUnknownFields),
-            TOKENIZATION_PARAMS);
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+                TOKENIZATION
+        );
         parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
-    private final TokenizationParams tokenizationParams;
+    private final Tokenization tokenization;
     private final List<String> classificationLabels;
 
-    public SentimentAnalysisConfig(VocabularyConfig vocabularyConfig, @Nullable TokenizationParams tokenizationParams,
+    public SentimentAnalysisConfig(VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization,
                                    @Nullable List<String> classificationLabels) {
         this.vocabularyConfig = ExceptionsHelper.requireNonNull(vocabularyConfig, VOCABULARY);
-        this.tokenizationParams = tokenizationParams == null ? TokenizationParams.createDefault() : tokenizationParams;
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
         this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
     }
 
     public SentimentAnalysisConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
-        tokenizationParams = new TokenizationParams(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
         classificationLabels = in.readStringList();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
-        tokenizationParams.writeTo(out);
+        out.writeNamedWriteable(tokenization);
         out.writeStringCollection(classificationLabels);
     }
 
@@ -75,7 +78,7 @@ public class SentimentAnalysisConfig implements NlpConfig {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig);
-        builder.field(TOKENIZATION_PARAMS.getPreferredName(), tokenizationParams);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
         if (classificationLabels.isEmpty() == false) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
@@ -110,13 +113,13 @@ public class SentimentAnalysisConfig implements NlpConfig {
 
         SentimentAnalysisConfig that = (SentimentAnalysisConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenizationParams, that.tokenizationParams)
+            && Objects.equals(tokenization, that.tokenization)
             && Objects.equals(classificationLabels, that.classificationLabels);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenizationParams, classificationLabels);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels);
     }
 
     @Override
@@ -125,8 +128,8 @@ public class SentimentAnalysisConfig implements NlpConfig {
     }
 
     @Override
-    public TokenizationParams getTokenizationParams() {
-        return tokenizationParams;
+    public Tokenization getTokenization() {
+        return tokenization;
     }
 
     public List<String> getClassificationLabels() {

+ 23 - 19
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TokenizationParams.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java

@@ -7,53 +7,54 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
+import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ParseField;
-import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 
-public class TokenizationParams implements ToXContentObject, Writeable {
+public abstract class Tokenization implements NamedXContentObject, NamedWriteable {
 
+    //TODO add global params like never_split, bos_token, eos_token, mask_token, tokenize_chinese_chars, strip_accents, etc.
     public static final ParseField DO_LOWER_CASE = new ParseField("do_lower_case");
     public static final ParseField WITH_SPECIAL_TOKENS = new ParseField("with_special_tokens");
     public static final ParseField MAX_SEQUENCE_LENGTH = new ParseField("max_sequence_length");
 
     private static final int DEFAULT_MAX_SEQUENCE_LENGTH = 512;
+    private static final boolean DEFAULT_DO_LOWER_CASE = false;
+    private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true;
 
-    public static ConstructingObjectParser<TokenizationParams, Void> createParser(boolean ignoreUnknownFields) {
-        ConstructingObjectParser<TokenizationParams, Void> parser = new ConstructingObjectParser<>("tokenization_params",
-            ignoreUnknownFields, a -> new TokenizationParams((Boolean) a[0], (Boolean) a[1], (Integer) a[2]));
+    static <T extends Tokenization> void declareCommonFields(ConstructingObjectParser<T, ?> parser) {
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE);
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), WITH_SPECIAL_TOKENS);
         parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_SEQUENCE_LENGTH);
-        return parser;
     }
 
-    private final boolean doLowerCase;
-    private final boolean withSpecialTokens;
-    private final int maxSequenceLength;
-
-    public static TokenizationParams createDefault() {
-        return new TokenizationParams(null, null, null);
+    public static BertTokenization createDefault() {
+        return new BertTokenization(null, null, null);
     }
 
-    public TokenizationParams(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) {
+    protected final boolean doLowerCase;
+    protected final boolean withSpecialTokens;
+    protected final int maxSequenceLength;
+
+    Tokenization(@Nullable Boolean doLowerCase, @Nullable Boolean withSpecialTokens, @Nullable Integer maxSequenceLength) {
         if (maxSequenceLength != null && maxSequenceLength <= 0) {
             throw new IllegalArgumentException("[" + MAX_SEQUENCE_LENGTH.getPreferredName() + "] must be positive");
         }
-        this.doLowerCase = doLowerCase == null ? false : doLowerCase;
-        this.withSpecialTokens = withSpecialTokens == null ? true : withSpecialTokens;
-        this.maxSequenceLength = maxSequenceLength == null ? DEFAULT_MAX_SEQUENCE_LENGTH : maxSequenceLength;
+        this.doLowerCase = Optional.ofNullable(doLowerCase).orElse(DEFAULT_DO_LOWER_CASE);
+        this.withSpecialTokens = Optional.ofNullable(withSpecialTokens).orElse(DEFAULT_WITH_SPECIAL_TOKENS);
+        this.maxSequenceLength = Optional.ofNullable(maxSequenceLength).orElse(DEFAULT_MAX_SEQUENCE_LENGTH);
     }
 
-    public TokenizationParams(StreamInput in) throws IOException {
+    public Tokenization(StreamInput in) throws IOException {
         this.doLowerCase = in.readBoolean();
         this.withSpecialTokens = in.readBoolean();
         this.maxSequenceLength = in.readVInt();
@@ -66,12 +67,15 @@ public class TokenizationParams implements ToXContentObject, Writeable {
         out.writeVInt(maxSequenceLength);
     }
 
+    abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(DO_LOWER_CASE.getPreferredName(), doLowerCase);
         builder.field(WITH_SPECIAL_TOKENS.getPreferredName(), withSpecialTokens);
         builder.field(MAX_SEQUENCE_LENGTH.getPreferredName(), maxSequenceLength);
+        builder = doXContentBody(builder, params);
         builder.endObject();
         return builder;
     }
@@ -80,7 +84,7 @@ public class TokenizationParams implements ToXContentObject, Writeable {
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
-        TokenizationParams that = (TokenizationParams) o;
+        Tokenization that = (Tokenization) o;
         return doLowerCase == that.doLowerCase
             && withSpecialTokens == that.withSpecialTokens
             && maxSequenceLength == that.maxSequenceLength;
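
For reference, a minimal sketch of the defaulting and validation behavior defined above; `createDefault()` returns a `BertTokenization`, so BERT remains the default when no tokenization is specified:

```java
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

// All-null arguments fall back to the class defaults:
// do_lower_case=false, with_special_tokens=true, max_sequence_length=512.
Tokenization defaults = Tokenization.createDefault();
assert defaults.equals(new BertTokenization(null, null, null));

// A non-positive max_sequence_length is rejected in the constructor:
// new BertTokenization(null, null, 0) -> IllegalArgumentException
```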

+ 36 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceConfigItemTestCase.java

@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public abstract class InferenceConfigItemTestCase<T extends Writeable & ToXContent> extends AbstractBWCSerializationTestCase<T> {
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+}

+ 6 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertPassThroughConfigTests.java

@@ -10,12 +10,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 import org.junit.Before;
 
 import java.io.IOException;
 
-public class BertPassThroughConfigTests extends AbstractBWCSerializationTestCase<BertPassThroughConfig> {
+public class BertPassThroughConfigTests extends InferenceConfigItemTestCase<BertPassThroughConfig> {
 
     private boolean lenient;
 
@@ -47,7 +47,9 @@ public class BertPassThroughConfigTests extends AbstractBWCSerializationTestCase
     public static BertPassThroughConfig createRandom() {
         return new BertPassThroughConfig(
             VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : TokenizationParamsTests.createRandom()
-        );
+            randomBoolean() ?
+                null :
+                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
+            );
     }
 }

+ 9 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TokenizationParamsTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationTests.java

@@ -15,7 +15,7 @@ import org.junit.Before;
 
 import java.io.IOException;
 
-public class TokenizationParamsTests extends AbstractBWCSerializationTestCase<TokenizationParams> {
+public class BertTokenizationTests extends AbstractBWCSerializationTestCase<BertTokenization> {
 
     private boolean lenient;
 
@@ -25,27 +25,27 @@ public class TokenizationParamsTests extends AbstractBWCSerializationTestCase<To
     }
 
     @Override
-    protected TokenizationParams doParseInstance(XContentParser parser) throws IOException {
-        return TokenizationParams.createParser(lenient).apply(parser, null);
+    protected BertTokenization doParseInstance(XContentParser parser) throws IOException {
+        return BertTokenization.createParser(lenient).apply(parser, null);
     }
 
     @Override
-    protected Writeable.Reader<TokenizationParams> instanceReader() {
-        return TokenizationParams::new;
+    protected Writeable.Reader<BertTokenization> instanceReader() {
+        return BertTokenization::new;
     }
 
     @Override
-    protected TokenizationParams createTestInstance() {
+    protected BertTokenization createTestInstance() {
         return createRandom();
     }
 
     @Override
-    protected TokenizationParams mutateInstanceForVersion(TokenizationParams instance, Version version) {
+    protected BertTokenization mutateInstanceForVersion(BertTokenization instance, Version version) {
         return instance;
     }
 
-    public static TokenizationParams createRandom() {
-        return new TokenizationParams(
+    public static BertTokenization createRandom() {
+        return new BertTokenization(
             randomBoolean() ? null : randomBoolean(),
             randomBoolean() ? null : randomBoolean(),
             randomBoolean() ? null : randomIntBetween(1, 1024)

+ 54 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenizationTests.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+
+public class DistilBertTokenizationTests extends AbstractBWCSerializationTestCase<DistilBertTokenization> {
+
+    private boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected DistilBertTokenization doParseInstance(XContentParser parser) throws IOException {
+        return DistilBertTokenization.createParser(lenient).apply(parser, null);
+    }
+
+    @Override
+    protected Writeable.Reader<DistilBertTokenization> instanceReader() {
+        return DistilBertTokenization::new;
+    }
+
+    @Override
+    protected DistilBertTokenization createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected DistilBertTokenization mutateInstanceForVersion(DistilBertTokenization instance, Version version) {
+        return instance;
+    }
+
+    public static DistilBertTokenization createRandom() {
+        return new DistilBertTokenization(
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomIntBetween(1, 1024)
+        );
+    }
+}

+ 6 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -10,12 +10,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 import org.junit.Before;
 
 import java.io.IOException;
 
-public class FillMaskConfigTests extends AbstractBWCSerializationTestCase<FillMaskConfig> {
+public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskConfig> {
 
     private boolean lenient;
 
@@ -47,7 +47,10 @@ public class FillMaskConfigTests extends AbstractBWCSerializationTestCase<FillMa
     public static FillMaskConfig createRandom() {
         return new FillMaskConfig(
             VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : TokenizationParamsTests.createRandom()
+            randomBoolean() ?
+                null :
+                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
+
         );
     }
 }

+ 5 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -10,12 +10,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 import org.junit.Before;
 
 import java.io.IOException;
 
-public class NerConfigTests extends AbstractBWCSerializationTestCase<NerConfig> {
+public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
 
     private boolean lenient;
 
@@ -47,7 +47,9 @@ public class NerConfigTests extends AbstractBWCSerializationTestCase<NerConfig>
     public static NerConfig createRandom() {
         return new NerConfig(
             VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : TokenizationParamsTests.createRandom(),
+            randomBoolean() ?
+                null :
+                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10))
         );
     }

+ 5 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SentimentAnalysisConfigTests.java

@@ -10,12 +10,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 import org.junit.Before;
 
 import java.io.IOException;
 
-public class SentimentAnalysisConfigTests extends AbstractBWCSerializationTestCase<SentimentAnalysisConfig> {
+public class SentimentAnalysisConfigTests extends InferenceConfigItemTestCase<SentimentAnalysisConfig> {
 
     private boolean lenient;
 
@@ -47,7 +47,9 @@ public class SentimentAnalysisConfigTests extends AbstractBWCSerializationTestCa
     public static SentimentAnalysisConfig createRandom() {
         return new SentimentAnalysisConfig(
             VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : TokenizationParamsTests.createRandom(),
+            randomBoolean() ?
+                null :
+                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10))
         );
     }

+ 2 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -383,8 +383,8 @@ public class PyTorchModelIT extends ESRestTestCase {
             "              \"index\": \"" + VOCAB_INDEX + "\",\n" +
             "              \"id\": \"test_vocab\"\n" +
             "            },\n" +
-            "            \"tokenization_params\": {" +
-            "              \"with_special_tokens\": false\n" +
+            "            \"tokenization\": {" +
+            "              \"bert\": {\"with_special_tokens\": false}\n" +
             "            }\n" +
             "        }\n" +
             "    },\n" +

+ 2 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

@@ -29,9 +29,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationParams;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
@@ -241,7 +241,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
                         .setInferenceConfig(
                             new BertPassThroughConfig(
                                 new VocabularyConfig(indexname, TRAINED_MODEL_ID + "_vocab"),
-                                new TokenizationParams(null, false, null)
+                                new BertTokenization(null, false, null)
                             )
                         )
                         .setLocation(new IndexLocation(indexname))

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -41,7 +41,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
@@ -255,7 +255,7 @@ public class DeploymentManager {
 
     private void waitForResult(ProcessContext processContext,
                                PyTorchResultProcessor.PendingResult pendingResult,
-                               BertTokenizer.TokenizationResult tokenization,
+                               TokenizationResult tokenization,
                                String requestId,
                                TimeValue timeout,
                                NlpTask.ResultProcessor inferenceResultsProcessor,

+ 3 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -10,8 +10,8 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -25,21 +25,14 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     static final String ARG3 = "arg_3";
 
     private final BertTokenizer tokenizer;
-    private final int maxSequenceLength;
 
-    public BertRequestBuilder(BertTokenizer tokenizer, int maxSequenceLength) {
+    public BertRequestBuilder(BertTokenizer tokenizer) {
         this.tokenizer = tokenizer;
-        this.maxSequenceLength = maxSequenceLength;
     }
 
     @Override
     public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize(input);
-        if (tokenization.getTokenIds().length > maxSequenceLength) {
-            throw ExceptionsHelper.badRequestException(
-                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
-                tokenization.getTokenIds().length, maxSequenceLength);
-        }
+        TokenizationResult tokenization = tokenizer.tokenize(input);
         return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
     }
 

+ 52 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+public class DistilBertRequestBuilder implements NlpTask.RequestBuilder {
+
+    static final String REQUEST_ID = "request_id";
+    static final String TOKENS = "tokens";
+    static final String ARG1 = "arg_1";
+
+    private final BertTokenizer tokenizer;
+
+    public DistilBertRequestBuilder(BertTokenizer tokenizer) {
+        this.tokenizer = tokenizer;
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
+        TokenizationResult result = tokenizer.tokenize(input);
+        return new NlpTask.Request(result, jsonRequest(result.getTokenIds(), requestId));
+    }
+
+    static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(REQUEST_ID, requestId);
+        builder.array(TOKENS, tokens);
+
+        int[] inputMask = new int[tokens.length];
+        Arrays.fill(inputMask, 1);
+
+        builder.array(ARG1, inputMask);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        return BytesReference.bytes(builder);
+    }
+}
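
A minimal usage sketch, assuming the toy vocabulary from the unit tests added later in this commit:

```java
BertTokenizer tokenizer = BertTokenizer.builder(
    Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
    new DistilBertTokenization(null, null, 512)
).build();

NlpTask.Request request = new DistilBertRequestBuilder(tokenizer).buildRequest("Elasticsearch fun", "request1");
// request.processInput is JSON of the form:
// {"request_id":"request1","tokens":[3,0,1,2,4],"arg_1":[1,1,1,1,1]}
```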

+ 8 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -12,6 +12,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -21,10 +23,10 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     private static final int NUM_RESULTS = 5;
 
-    private final BertRequestBuilder bertRequestBuilder;
+    private final NlpTask.RequestBuilder requestBuilder;
 
-    FillMaskProcessor(BertTokenizer tokenizer, FillMaskConfig config) {
-        this.bertRequestBuilder = new BertRequestBuilder(tokenizer, config.getTokenizationParams().maxSequenceLength());
+    FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
+        this.requestBuilder = tokenizer.requestBuilder();
     }
 
     @Override
@@ -46,15 +48,15 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.RequestBuilder getRequestBuilder() {
-        return bertRequestBuilder;
+        return requestBuilder;
     }
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor() {
-        return (tokenization, pyTorchResult) -> processResult(tokenization, pyTorchResult);
+        return this::processResult;
     }
 
-    InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+    InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
 
         if (tokenization.getTokens().isEmpty()) {
             return new FillMaskResults(Collections.emptyList());

+ 9 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -14,7 +14,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -67,13 +68,13 @@ public class NerProcessor implements NlpTask.Processor {
         }
     }
 
-    private final BertRequestBuilder bertRequestBuilder;
+    private final NlpTask.RequestBuilder requestBuilder;
     private final IobTag[] iobMap;
 
-    NerProcessor(BertTokenizer tokenizer, NerConfig config) {
-        this.bertRequestBuilder = new BertRequestBuilder(tokenizer, config.getTokenizationParams().maxSequenceLength());
+    NerProcessor(NlpTokenizer tokenizer, NerConfig config) {
         validate(config.getClassificationLabels());
-        iobMap = buildIobMap(config.getClassificationLabels());
+        this.iobMap = buildIobMap(config.getClassificationLabels());
+        this.requestBuilder = tokenizer.requestBuilder();
     }
 
     /**
@@ -124,7 +125,7 @@ public class NerProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.RequestBuilder getRequestBuilder() {
-        return bertRequestBuilder;
+        return requestBuilder;
     }
 
     @Override
@@ -133,7 +134,6 @@ public class NerProcessor implements NlpTask.Processor {
     }
 
     static class NerResultProcessor implements NlpTask.ResultProcessor {
-
         private final IobTag[] iobMap;
 
         NerResultProcessor(IobTag[] iobMap) {
@@ -141,7 +141,7 @@ public class NerProcessor implements NlpTask.Processor {
         }
 
         @Override
-        public InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
             if (tokenization.getTokens().isEmpty()) {
                 return new NerResults(Collections.emptyList());
             }
@@ -163,7 +163,7 @@ public class NerProcessor implements NlpTask.Processor {
          * in the original input replacing them with a single token that
          * gets labelled based on the average score of all its sub-tokens.
          */
-        private List<TaggedToken> tagTokens(BertTokenizer.TokenizationResult tokenization, double[][] scores) {
+        private List<TaggedToken> tagTokens(TokenizationResult tokenization, double[][] scores) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
             while (startTokenIndex < tokenization.getTokens().size()) {

+ 11 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -15,7 +15,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.Map;
@@ -24,14 +25,11 @@ import java.util.Objects;
 public class NlpTask {
 
     private final NlpConfig config;
-    private final BertTokenizer tokenizer;
+    private final NlpTokenizer tokenizer;
 
     public NlpTask(NlpConfig config, Vocabulary vocabulary) {
         this.config = config;
-        this.tokenizer = BertTokenizer.builder(vocabulary.get())
-            .setWithSpecialTokens(config.getTokenizationParams().withSpecialTokens())
-            .setDoLowerCase(config.getTokenizationParams().doLowerCase())
-            .build();
+        this.tokenizer = NlpTokenizer.build(vocabulary, config.getTokenization());
     }
 
     /**
@@ -48,7 +46,11 @@ public class NlpTask {
     }
 
     public interface ResultProcessor {
-        InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult);
+        InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult);
+    }
+
+    public interface ResultProcessorFactory {
+        ResultProcessor build(TokenizationResult tokenizationResult);
     }
 
     public interface Processor {
@@ -78,10 +80,10 @@ public class NlpTask {
     }
 
     public static class Request {
-        public final BertTokenizer.TokenizationResult tokenization;
+        public final TokenizationResult tokenization;
         public final BytesReference processInput;
 
-        public Request(BertTokenizer.TokenizationResult tokenization, BytesReference processInput) {
+        public Request(TokenizationResult tokenization, BytesReference processInput) {
             this.tokenization = Objects.requireNonNull(tokenization);
             this.processInput = Objects.requireNonNull(processInput);
         }
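
`ResultProcessor` is now a one-method interface over the shared `TokenizationResult`, so implementations can be lambdas or method references, as the processor diffs in this commit show. A sketch:

```java
// A pass-through result processor expressed as a lambda (sketch; compare
// PassThroughProcessor below, which uses an equivalent method reference):
NlpTask.ResultProcessor passThrough =
    (tokenization, pyTorchResult) -> new PyTorchPassThroughResults(pyTorchResult.getInferenceResult());
```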

+ 8 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -11,7 +11,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 /**
 * An NLP processor that directly returns the PyTorch result
@@ -19,10 +20,10 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
  */
 public class PassThroughProcessor implements NlpTask.Processor {
 
-    private final BertRequestBuilder bertRequestBuilder;
+    private final NlpTask.RequestBuilder requestBuilder;
 
-    PassThroughProcessor(BertTokenizer tokenizer, BertPassThroughConfig config) {
-        this.bertRequestBuilder = new BertRequestBuilder(tokenizer, config.getTokenizationParams().maxSequenceLength());
+    PassThroughProcessor(NlpTokenizer tokenizer, BertPassThroughConfig config) {
+        this.requestBuilder = tokenizer.requestBuilder();
     }
 
     @Override
@@ -32,15 +33,15 @@ public class PassThroughProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.RequestBuilder getRequestBuilder() {
-        return bertRequestBuilder;
+        return requestBuilder;
     }
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor() {
-        return this::processResult;
+        return PassThroughProcessor::processResult;
     }
 
-    private InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+    private static InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         return new PyTorchPassThroughResults(pyTorchResult.getInferenceResult());
     }
 }

+ 6 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessor.java

@@ -16,7 +16,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResult
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -25,10 +26,10 @@ import java.util.Locale;
 
 public class SentimentAnalysisProcessor implements NlpTask.Processor {
 
-    private final BertTokenizer tokenizer;
+    private final NlpTokenizer tokenizer;
     private final List<String> classLabels;
 
-    SentimentAnalysisProcessor(BertTokenizer tokenizer, SentimentAnalysisConfig config) {
+    SentimentAnalysisProcessor(NlpTokenizer tokenizer, SentimentAnalysisConfig config) {
         this.tokenizer = tokenizer;
         List<String> classLabels = config.getClassificationLabels();
         if (classLabels == null || classLabels.isEmpty()) {
@@ -60,7 +61,7 @@ public class SentimentAnalysisProcessor implements NlpTask.Processor {
     }
 
     NlpTask.Request buildRequest(String input, String requestId) throws IOException {
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize(input);
+        TokenizationResult tokenization = tokenizer.tokenize(input);
         return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
     }
 
@@ -69,7 +70,7 @@ public class SentimentAnalysisProcessor implements NlpTask.Processor {
         return this::processResult;
     }
 
-    InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+    InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         if (pyTorchResult.getInferenceResult().length < 1) {
             return new WarningInferenceResults("Sentiment analysis result has no data");
         }

+ 6 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -12,7 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 import java.util.Locale;
 
@@ -20,30 +20,30 @@ public enum TaskType {
 
     NER {
         @Override
-        public NlpTask.Processor createProcessor(BertTokenizer tokenizer, NlpConfig config) {
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new NerProcessor(tokenizer, (NerConfig) config);
         }
     },
     SENTIMENT_ANALYSIS {
         @Override
-        public NlpTask.Processor createProcessor(BertTokenizer tokenizer, NlpConfig config) {
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new SentimentAnalysisProcessor(tokenizer, (SentimentAnalysisConfig) config);
         }
     },
     FILL_MASK {
         @Override
-        public NlpTask.Processor createProcessor(BertTokenizer tokenizer, NlpConfig config) {
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new FillMaskProcessor(tokenizer, (FillMaskConfig) config);
         }
     },
     BERT_PASS_THROUGH {
         @Override
-        public NlpTask.Processor createProcessor(BertTokenizer tokenizer, NlpConfig config) {
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new PassThroughProcessor(tokenizer, (BertPassThroughConfig) config);
         }
     };
 
-    public NlpTask.Processor createProcessor(BertTokenizer tokenizer, NlpConfig config) {
+    public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
         throw new UnsupportedOperationException("json request must be specialised for task type [" + this.name() + "]");
     }
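
Since every task type now accepts the `NlpTokenizer` abstraction, the same wiring serves BERT and DistilBERT models. A hedged sketch of that wiring, assuming a `vocabulary` and an NER `config` are already in scope:

```java
// Build the tokenizer once from the tokenization config, then dispatch on task type:
NlpTokenizer tokenizer = NlpTokenizer.build(vocabulary, config.getTokenization());
NlpTask.Processor processor = TaskType.NER.createProcessor(tokenizer, config);
NlpTask.Request request = processor.getRequestBuilder().buildRequest("some input text", "request1");
```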
 

+ 72 - 72
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -7,6 +7,10 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -16,6 +20,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.function.Function;
 
 /**
  * Performs basic tokenization and normalization of input text
@@ -25,7 +30,7 @@ import java.util.TreeMap;
  * Derived from
  * https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/src/transformers/tokenization_bert.py
  */
-public class BertTokenizer {
+public class BertTokenizer implements NlpTokenizer {
 
     public static final String UNKNOWN_TOKEN = "[UNK]";
     public static final String SEPARATOR_TOKEN = "[SEP]";
@@ -48,15 +53,18 @@ public class BertTokenizer {
     private final boolean doStripAccents;
     private final boolean withSpecialTokens;
     private final Set<String> neverSplit;
-
-    private BertTokenizer(
-                          List<String> originalVocab,
-                          SortedMap<String, Integer> vocab,
-                          boolean doLowerCase,
-                          boolean doTokenizeCjKChars,
-                          boolean doStripAccents,
-                          boolean withSpecialTokens,
-                          Set<String> neverSplit) {
+    private final int maxSequenceLength;
+    private final NlpTask.RequestBuilder requestBuilder;
+
+    protected BertTokenizer(List<String> originalVocab,
+                            SortedMap<String, Integer> vocab,
+                            boolean doLowerCase,
+                            boolean doTokenizeCjKChars,
+                            boolean doStripAccents,
+                            boolean withSpecialTokens,
+                            int maxSequenceLength,
+                            Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+                            Set<String> neverSplit) {
         wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
         this.originalVocab = originalVocab;
         this.vocab = vocab;
@@ -65,6 +73,8 @@ public class BertTokenizer {
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
         this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
+        this.maxSequenceLength = maxSequenceLength;
+        this.requestBuilder = requestBuilderFactory.apply(this);
     }
 
     /**
@@ -76,6 +86,7 @@ public class BertTokenizer {
      * @param text Text to tokenize
      * @return Tokenized text, token Ids and map
      */
+    @Override
     public TokenizationResult tokenize(String text) {
         BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
 
@@ -126,79 +137,48 @@ public class BertTokenizer {
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
 
-        return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
-    }
-
-    public static class TokenizationResult {
-
-        String input;
-        List<String> vocab;
-        private final List<String> tokens;
-        private final int [] tokenIds;
-        private final int [] tokenMap;
-
-        public TokenizationResult(String input, List<String> vocab, List<String> tokens, int[] tokenIds, int[] tokenMap) {
-            assert tokens.size() == tokenIds.length;
-            assert tokenIds.length == tokenMap.length;
-            this.input = input;
-            this.vocab = vocab;
-            this.tokens = tokens;
-            this.tokenIds = tokenIds;
-            this.tokenMap = tokenMap;
-        }
-
-        public String getFromVocab(int tokenId) {
-            return vocab.get(tokenId);
+        if (tokenIds.length > maxSequenceLength) {
+            throw ExceptionsHelper.badRequestException(
+                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
+                tokenIds.length,
+                maxSequenceLength
+            );
         }
 
-        /**
-         * The token strings from the tokenization process
-         * @return A list of tokens
-         */
-        public List<String> getTokens() {
-            return tokens;
-        }
-
-        /**
-         * The integer values of the tokens in {@link #getTokens()}
-         * @return A list of token Ids
-         */
-        public int[] getTokenIds() {
-            return tokenIds;
-        }
+        return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
+    }
 
-        /**
-         * Maps the token position to the position in the source text.
-         * Source words may be divided into more than one token so more
-         * than one token can map back to the source token
-         * @return Map of source token to
-         */
-        public int[] getTokenMap() {
-            return tokenMap;
-        }
+    @Override
+    public NlpTask.RequestBuilder requestBuilder() {
+        return requestBuilder;
+    }
 
-        public String getInput() {
-            return input;
-        }
+    public int getMaxSequenceLength() {
+        return maxSequenceLength;
     }
 
-    public static Builder builder(List<String> vocab) {
-        return new Builder(vocab);
+    public static Builder builder(List<String> vocab, Tokenization tokenization) {
+        return new Builder(vocab, tokenization);
     }
 
     public static class Builder {
 
-        private final List<String> originalVocab;
-        private final SortedMap<String, Integer> vocab;
-        private boolean doLowerCase = false;
-        private boolean doTokenizeCjKChars = true;
-        private boolean withSpecialTokens = true;
-        private Boolean doStripAccents = null;
-        private Set<String> neverSplit;
-
-        private Builder(List<String> vocab) {
+        protected final List<String> originalVocab;
+        protected final SortedMap<String, Integer> vocab;
+        protected boolean doLowerCase = false;
+        protected boolean doTokenizeCjKChars = true;
+        protected boolean withSpecialTokens = true;
+        protected int maxSequenceLength;
+        protected Boolean doStripAccents = null;
+        protected Set<String> neverSplit;
+        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
+
+        protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
             this.vocab = buildSortedVocab(vocab);
+            this.doLowerCase = tokenization.doLowerCase();
+            this.withSpecialTokens = tokenization.withSpecialTokens();
+            this.maxSequenceLength = tokenization.maxSequenceLength();
         }
 
         private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
@@ -229,6 +209,11 @@ public class BertTokenizer {
             return this;
         }
 
+        public Builder setMaxSequenceLength(int maxSequenceLength) {
+            this.maxSequenceLength = maxSequenceLength;
+            return this;
+        }
+
         /**
          * Include CLS and SEP tokens
          * @param withSpecialTokens if true include CLS and SEP tokens
@@ -239,6 +224,11 @@ public class BertTokenizer {
             return this;
         }
 
+        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+            this.requestBuilderFactory = requestBuilderFactory;
+            return this;
+        }
+
         public BertTokenizer build() {
             // if not set strip accents defaults to the value of doLowerCase
             if (doStripAccents == null) {
@@ -249,7 +239,17 @@ public class BertTokenizer {
                 neverSplit = Collections.emptySet();
             }
 
-            return new BertTokenizer(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, neverSplit);
+            return new BertTokenizer(
+                originalVocab,
+                vocab,
+                doLowerCase,
+                doTokenizeCjKChars,
+                doStripAccents,
+                withSpecialTokens,
+                maxSequenceLength,
+                requestBuilderFactory,
+                neverSplit
+            );
         }
     }
 }
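
The max-sequence-length check now lives in `tokenize()` itself rather than in each request builder. A sketch of the new behaviour, using the five-token vocabulary from the tests below:

```java
BertTokenizer tokenizer = BertTokenizer.builder(
    Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
    new BertTokenization(null, null, 5) // maxSequenceLength = 5
).build();

tokenizer.tokenize("Elasticsearch fun"); // 3 tokens + [CLS] and [SEP] = 5: fits exactly
// A longer input now throws from the tokenizer, not the request builder:
// "Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"
```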

+ 39 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.DistilBertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.VOCABULARY;
+
+public interface NlpTokenizer {
+
+    TokenizationResult tokenize(String text);
+
+    NlpTask.RequestBuilder requestBuilder();
+
+    static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) {
+        ExceptionsHelper.requireNonNull(params, TOKENIZATION);
+        ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
+        if (params instanceof BertTokenization) {
+            return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
+        }
+        if (params instanceof DistilBertTokenization) {
+            return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(DistilBertRequestBuilder::new).build();
+        }
+        throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
+    }
+}
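
Both branches reuse `BertTokenizer`, since the tokenization itself is identical; only the request-builder factory differs. A sketch of the two resulting payload shapes, assuming a `vocabulary` in scope (the `arg_2`/`arg_3` summary follows the commit description):

```java
NlpTokenizer bert = NlpTokenizer.build(vocabulary, new BertTokenization(null, null, 512));
bert.requestBuilder();       // BertRequestBuilder: tokens + arg_1, arg_2, arg_3

NlpTokenizer distilBert = NlpTokenizer.build(vocabulary, new DistilBertTokenization(null, null, 512));
distilBert.requestBuilder(); // DistilBertRequestBuilder: tokens + arg_1 only
```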

+ 63 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import java.util.List;
+
+public class TokenizationResult {
+
+    String input;
+    final List<String> vocab;
+    private final List<String> tokens;
+    private final int [] tokenIds;
+    private final int [] tokenMap;
+
+    public TokenizationResult(String input, List<String> vocab, List<String> tokens, int[] tokenIds, int[] tokenMap) {
+        assert tokens.size() == tokenIds.length;
+        assert tokenIds.length == tokenMap.length;
+        this.input = input;
+        this.vocab = vocab;
+        this.tokens = tokens;
+        this.tokenIds = tokenIds;
+        this.tokenMap = tokenMap;
+    }
+
+    public String getFromVocab(int tokenId) {
+        return vocab.get(tokenId);
+    }
+
+    /**
+     * The token strings from the tokenization process
+     * @return A list of tokens
+     */
+    public List<String> getTokens() {
+        return tokens;
+    }
+
+    /**
+     * The integer values of the tokens in {@link #getTokens()}
+     * @return A list of token Ids
+     */
+    public int[] getTokenIds() {
+        return tokenIds;
+    }
+
+    /**
+     * Maps each token position to its position in the source text.
+     * Source words may be divided into more than one token, so more
+     * than one token can map back to the same source word.
+     * @return Map of token position to source text position
+     */
+    public int[] getTokenMap() {
+        return tokenMap;
+    }
+
+    public String getInput() {
+        return input;
+    }
+}
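
The extracted class carries the token-to-source mapping unchanged. A worked example of `getTokenMap()`, assuming a tokenizer built with the three-word vocabulary from `BertTokenizerTests` below:

```java
// For input "Elasticsearch fun" with vocab ["Elastic", "##search", "fun"]:
//   getTokens()   -> ["Elastic", "##search", "fun"]
//   getTokenIds() -> [0, 1, 2]
//   getTokenMap() -> [0, 0, 1]   // both word pieces of "Elasticsearch" map to source word 0
TokenizationResult result = tokenizer.tokenize("Elasticsearch fun");
int sourceWord = result.getTokenMap()[1]; // 0: "##search" belongs to the first source word
```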

+ 10 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 
 import java.io.IOException;
@@ -24,9 +25,11 @@ public class BertRequestBuilderTests extends ESTestCase {
 
     public void testBuildRequest() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            new BertTokenization(null, null, 512)
+        ).build();
 
-        BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer, 512);
+        BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
         NlpTask.Request request = requestBuilder.buildRequest("Elasticsearch fun", "request1");
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
@@ -41,10 +44,11 @@ public class BertRequestBuilderTests extends ESTestCase {
 
     public void testInputTooLarge() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
-
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            new BertTokenization(null, null, 5)
+        ).build();
         {
-            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer, 5);
+            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
             ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
                 () -> requestBuilder.buildRequest("Elasticsearch fun Elasticsearch fun Elasticsearch fun", "request1"));
 
@@ -52,7 +56,7 @@ public class BertRequestBuilderTests extends ESTestCase {
                 containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
         }
         {
-            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer, 5);
+            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
             // input will become 3 tokens + the Class and Separator tokens = 5, which is
             // our max sequence length
             requestBuilder.buildRequest("Elasticsearch fun", "request1");

+ 64 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasSize;
+
+public class DistilBertRequestBuilderTests extends ESTestCase {
+
+    public void testBuildRequest() throws IOException {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            new DistilBertTokenization(null, null, 512)
+        ).build();
+
+        DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
+        BytesReference bytesReference = requestBuilder.buildRequest("Elasticsearch fun", "request1").processInput;
+
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens"));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1"));
+    }
+
+    public void testInputTooLarge() throws IOException {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            new DistilBertTokenization(null, null, 5)
+        ).build();
+        {
+            DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
+            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+                () -> requestBuilder.buildRequest("Elasticsearch fun Elasticsearch fun Elasticsearch fun", "request1"));
+
+            assertThat(e.getMessage(),
+                containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
+        }
+        {
+            DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
+            // input will become 3 tokens + the Class and Separator tokens = 5, which is
+            // our max sequence length
+            requestBuilder.buildRequest("Elasticsearch fun", "request1");
+        }
+    }
+}

+ 4 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -44,8 +45,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5};
 
-        BertTokenizer.TokenizationResult tokenization = new BertTokenizer.TokenizationResult(input, vocab, tokens,
-            tokenIds, tokenMap);
+        TokenizationResult tokenization = new TokenizationResult(input, vocab, tokens, tokenIds, tokenMap);
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
 
@@ -66,8 +66,8 @@ public class FillMaskProcessorTests extends ESTestCase {
     }
 
     public void testProcessResults_GivenMissingTokens() {
-        BertTokenizer.TokenizationResult tokenization =
-            new BertTokenizer.TokenizationResult("", Collections.emptyList(), Collections.emptyList(),
+        TokenizationResult tokenization =
+            new TokenizationResult("", Collections.emptyList(), Collections.emptyList(),
             new int[] {}, new int[] {});
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);

+ 20 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -10,10 +10,13 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -81,15 +84,17 @@ public class NerProcessorTests extends ESTestCase {
 
     public void testProcessResults_GivenNoTokens() {
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values());
-        BertTokenizer.TokenizationResult tokenization = tokenize(Collections.emptyList(), "");
+        TokenizationResult tokenization = tokenize(Collections.emptyList(), "");
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("test", null, 0L, null));
         assertThat(result.getEntityGroups(), is(empty()));
     }
 
     public void testProcessResults() {
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values());
-        BertTokenizer.TokenizationResult tokenization = tokenize(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
-            "Many use Elasticsearch in London");
+        TokenizationResult tokenization = tokenize(
+            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
+            "Many use Elasticsearch in London"
+        );
         double[][] scores = {
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // many
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // use
@@ -123,8 +128,10 @@ public class NerProcessorTests extends ESTestCase {
         };
 
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap);
-        BertTokenizer.TokenizationResult tokenization = tokenize(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
-            "Elasticsearch in London");
+        TokenizationResult tokenization = tokenize(
+            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
+            "Elasticsearch in London"
+        );
 
         double[][] scores = {
             { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0}, // el
@@ -210,11 +217,14 @@ public class NerProcessorTests extends ESTestCase {
         assertThat(entityGroups.get(2).getLabel(), equalTo("organisation"));
     }
 
-    private static BertTokenizer.TokenizationResult tokenize(List<String> vocab, String input) {
-        BertTokenizer tokenizer = BertTokenizer.builder(vocab)
-            .setDoLowerCase(true)
-            .setWithSpecialTokens(false)
-            .build();
+    private static TokenizationResult tokenize(List<String> vocab, String input) {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            vocab,
+            randomFrom(
+                new BertTokenization(true, false, null),
+                new DistilBertTokenization(true, false, null)
+            )
+        ).setDoLowerCase(true).setWithSpecialTokens(false).build();
         return tokenizer.tokenize(input);
     }
 }

+ 4 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessorTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -51,7 +52,9 @@ public class SentimentAnalysisProcessorTests extends ESTestCase {
 
     public void testBuildRequest() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            new BertTokenization(null, null, 512)
+        ).build();
 
         SentimentAnalysisConfig config = new SentimentAnalysisConfig(new VocabularyConfig("test-index", "vocab"), null, null);
         SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(tokenizer, config);

+ 30 - 23
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -8,6 +8,8 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -17,20 +19,24 @@ import static org.hamcrest.Matchers.contains;
 public class BertTokenizerTests extends ESTestCase {
 
     public void testTokenize() {
-        BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList("Elastic", "##search", "fun"))
-            .setWithSpecialTokens(false).build();
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList("Elastic", "##search", "fun"),
+            new BertTokenization(null, false, null)
+        ).build();
 
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
         assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
     }
 
     public void testTokenizeAppendSpecialTokens() {
-        BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList(
-            "elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList( "elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            Tokenization.createDefault()
+        ).build();
 
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("elasticsearch fun");
+        TokenizationResult tokenization = tokenizer.tokenize("elasticsearch fun");
         assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
@@ -40,12 +46,13 @@ public class BertTokenizerTests extends ESTestCase {
         final String specialToken = "SP001";
 
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", specialToken, BertTokenizer.UNKNOWN_TOKEN))
-            .setNeverSplit(Collections.singleton(specialToken))
-            .setWithSpecialTokens(false)
-            .build();
+            Arrays.asList("Elastic", "##search", "fun", specialToken, BertTokenizer.UNKNOWN_TOKEN),
+            Tokenization.createDefault()
+        ).setNeverSplit(Collections.singleton(specialToken))
+         .setWithSpecialTokens(false)
+         .build();
 
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
+        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
         assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
@@ -54,12 +61,13 @@ public class BertTokenizerTests extends ESTestCase {
     public void testDoLowerCase() {
         {
             BertTokenizer tokenizer = BertTokenizer.builder(
-                Arrays.asList("elastic", "##search", "fun", BertTokenizer.UNKNOWN_TOKEN))
-                .setDoLowerCase(false)
-                .setWithSpecialTokens(false)
-                .build();
+                Arrays.asList("elastic", "##search", "fun", BertTokenizer.UNKNOWN_TOKEN),
+                Tokenization.createDefault()
+            ).setDoLowerCase(false)
+             .setWithSpecialTokens(false)
+             .build();
 
-            BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
             assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
             assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
@@ -69,24 +77,23 @@ public class BertTokenizerTests extends ESTestCase {
         }
 
         {
-            BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList("elastic", "##search", "fun"))
+            BertTokenizer tokenizer = BertTokenizer.builder(Arrays.asList("elastic", "##search", "fun"), Tokenization.createDefault())
                 .setDoLowerCase(true)
                 .setWithSpecialTokens(false)
                 .build();
 
-            BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
             assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
         }
     }
 
     public void testPunctuation() {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", ".", ",",
-                BertTokenizer.MASK_TOKEN, BertTokenizer.UNKNOWN_TOKEN))
-            .setWithSpecialTokens(false)
-            .build();
+            Arrays.asList("Elastic", "##search", "fun", ".", ",", BertTokenizer.MASK_TOKEN, BertTokenizer.UNKNOWN_TOKEN),
+            Tokenization.createDefault()
+        ).setWithSpecialTokens(false).build();
 
-        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch, fun.");
+        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch, fun.");
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());