瀏覽代碼

[ML] removing unnecessary distil_bert tokenization type (#78135)

removes tokenization type in favor of having models wrapped and thus unifying the input parameters to all BERT models.
Benjamin Trent 4 年之前
父節點
當前提交
c07570cf06
共有 13 個文件被更改,包括 11 次插入和 338 次刪除
  1. 0 15
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  2. 0 82
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenization.java
  3. 0 54
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenizationTests.java
  4. 1 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java
  5. 1 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java
  6. 2 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java
  7. 1 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java
  8. 2 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java
  9. 0 55
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java
  10. 0 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  11. 0 101
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java
  12. 1 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  13. 3 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

+ 0 - 15
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -28,7 +28,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResu
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -206,13 +205,6 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 (p, c) -> BertTokenization.fromXContent(p, (boolean) c)
             )
         );
-        namedXContent.add(
-            new NamedXContentRegistry.Entry(
-                Tokenization.class,
-                DistilBertTokenization.NAME,
-                (p, c) -> DistilBertTokenization.fromXContent(p, (boolean) c)
-            )
-        );
 
         return namedXContent;
     }
@@ -318,13 +310,6 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
                 BertTokenization::new
             )
         );
-        namedWriteables.add(
-            new NamedWriteableRegistry.Entry(
-                Tokenization.class,
-                DistilBertTokenization.NAME.getPreferredName(),
-                DistilBertTokenization::new
-            )
-        );
 
         return namedWriteables;
     }

+ 0 - 82
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenization.java

@@ -1,82 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ConstructingObjectParser;
-import org.elasticsearch.common.xcontent.ParseField;
-import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.core.Nullable;
-
-import java.io.IOException;
-
-public class DistilBertTokenization extends Tokenization {
-
-    public static final ParseField NAME = new ParseField("distil_bert");
-
-    public static ConstructingObjectParser<DistilBertTokenization, Void> createParser(boolean ignoreUnknownFields) {
-        ConstructingObjectParser<DistilBertTokenization, Void> parser = new ConstructingObjectParser<>(
-            "distil_bert_tokenization",
-            ignoreUnknownFields,
-            a -> new DistilBertTokenization((Boolean) a[0], (Boolean) a[1], (Integer) a[2])
-        );
-        Tokenization.declareCommonFields(parser);
-        return parser;
-    }
-
-    private static final ConstructingObjectParser<DistilBertTokenization, Void> LENIENT_PARSER = createParser(true);
-    private static final ConstructingObjectParser<DistilBertTokenization, Void> STRICT_PARSER = createParser(false);
-
-    public static DistilBertTokenization fromXContent(XContentParser parser, boolean lenient) {
-        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
-    }
-
-    public DistilBertTokenization(
-        @Nullable Boolean doLowerCase,
-        @Nullable Boolean withSpecialTokens,
-        @Nullable Integer maxSequenceLength
-    ) {
-        super(doLowerCase, withSpecialTokens, maxSequenceLength);
-    }
-
-    public DistilBertTokenization(StreamInput in) throws IOException {
-        super(in);
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
-    }
-
-    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
-        return builder;
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME.getPreferredName();
-    }
-
-    @Override
-    public String getName() {
-        return NAME.getPreferredName();
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (o == null || getClass() != o.getClass()) return false;
-        return super.equals(o);
-    }
-
-    @Override
-    public int hashCode() {
-        return super.hashCode();
-    }
-}

+ 0 - 54
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/DistilBertTokenizationTests.java

@@ -1,54 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
-
-import org.elasticsearch.Version;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
-import org.junit.Before;
-
-import java.io.IOException;
-
-public class DistilBertTokenizationTests extends AbstractBWCSerializationTestCase<DistilBertTokenization> {
-
-    private boolean lenient;
-
-    @Before
-    public void chooseStrictOrLenient() {
-        lenient = randomBoolean();
-    }
-
-    @Override
-    protected DistilBertTokenization doParseInstance(XContentParser parser) throws IOException {
-        return DistilBertTokenization.createParser(lenient).apply(parser, null);
-    }
-
-    @Override
-    protected Writeable.Reader<DistilBertTokenization> instanceReader() {
-        return DistilBertTokenization::new;
-    }
-
-    @Override
-    protected DistilBertTokenization createTestInstance() {
-        return createRandom();
-    }
-
-    @Override
-    protected DistilBertTokenization mutateInstanceForVersion(DistilBertTokenization instance, Version version) {
-        return instance;
-    }
-
-    public static DistilBertTokenization createRandom() {
-        return new DistilBertTokenization(
-            randomBoolean() ? null : randomBoolean(),
-            randomBoolean() ? null : randomBoolean(),
-            randomBoolean() ? null : randomIntBetween(1, 1024)
-        );
-    }
-}

+ 1 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -50,10 +50,7 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
     public static FillMaskConfig createRandom() {
         return new FillMaskConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ?
-                null :
-                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
-
+            randomBoolean() ? null : BertTokenizationTests.createRandom()
         );
     }
 }

+ 1 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -50,9 +50,7 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
     public static NerConfig createRandom() {
         return new NerConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ?
-                null :
-                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10))
         );
     }

+ 2 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java

@@ -50,9 +50,7 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
     public static PassThroughConfig createRandom() {
         return new PassThroughConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ?
-                null :
-                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
-            );
+            randomBoolean() ? null : BertTokenizationTests.createRandom()
+        );
     }
 }

+ 1 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java

@@ -50,9 +50,7 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
     public static TextClassificationConfig createRandom() {
         return new TextClassificationConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ?
-                null :
-                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom()),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
             randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
             randomBoolean() ? null : randomIntBetween(-1, 10)
         );

+ 2 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java

@@ -50,9 +50,7 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEm
     public static TextEmbeddingConfig createRandom() {
         return new TextEmbeddingConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ?
-                null :
-                randomFrom(BertTokenizationTests.createRandom(), DistilBertTokenizationTests.createRandom())
-            );
+            randomBoolean() ? null : BertTokenizationTests.createRandom()
+        );
     }
 }

+ 0 - 55
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java

@@ -1,55 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp;
-
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
-
-import java.io.IOException;
-import java.util.List;
-
-public class DistilBertRequestBuilder implements NlpTask.RequestBuilder {
-
-    static final String REQUEST_ID = "request_id";
-    static final String TOKENS = "tokens";
-    static final String ARG1 = "arg_1";
-
-    private final BertTokenizer tokenizer;
-
-    public DistilBertRequestBuilder(BertTokenizer tokenizer) {
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
-        if (tokenizer.getPadToken().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
-                " token in its vocabulary");
-        }
-
-        TokenizationResult result = tokenizer.tokenize(inputs);
-        return new NlpTask.Request(result, jsonRequest(result, tokenizer.getPadToken().getAsInt(), requestId));
-    }
-
-    static BytesReference jsonRequest(TokenizationResult tokenization,
-                                      int padToken,
-                                      String requestId) throws IOException {
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        builder.field(REQUEST_ID, requestId);
-        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
-        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
-        builder.endObject();
-
-        // BytesReference.bytes closes the builder
-        return BytesReference.bytes(builder);
-    }
-}

+ 0 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -8,11 +8,9 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
-import org.elasticsearch.xpack.ml.inference.nlp.DistilBertRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
 
@@ -36,9 +34,6 @@ public interface NlpTokenizer {
         if (params instanceof BertTokenization) {
             return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
         }
-        if (params instanceof DistilBertTokenization) {
-            return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(DistilBertRequestBuilder::new).build();
-        }
         throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
     }
 }

+ 0 - 101
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java

@@ -1,101 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp;
-
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.xcontent.XContentHelper;
-import org.elasticsearch.common.xcontent.XContentType;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-
-import static org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilderTests.nthListItemFromMap;
-import static org.hamcrest.Matchers.containsString;
-import static org.hamcrest.Matchers.hasSize;
-
-public class DistilBertRequestBuilderTests extends ESTestCase {
-
-    public void testBuildRequest() throws IOException {
-        BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
-            new DistilBertTokenization(null, null, 512)
-        ).build();
-
-        DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
-        BytesReference bytesReference = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1").processInput;
-
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
-
-        assertThat(jsonDocAsMap.keySet(), hasSize(3));
-        assertEquals("request1", jsonDocAsMap.get("request_id"));
-        assertEquals(Arrays.asList(3, 0, 1, 2, 4), nthListItemFromMap("tokens", 0, jsonDocAsMap));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
-    }
-
-    public void testInputTooLarge() throws IOException {
-        BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
-            new DistilBertTokenization(null, null, 5)
-        ).build();
-        {
-            DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
-            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-                () -> requestBuilder.buildRequest(List.of("Elasticsearch fun Elasticsearch fun Elasticsearch fun"), "request1"));
-
-            assertThat(e.getMessage(),
-                containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
-        }
-        {
-            DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
-            // input will become 3 tokens + the Class and Separator token = 5 which is
-            // our max sequence length
-            requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1");
-        }
-    }
-
-    @SuppressWarnings("unchecked")
-    public void testBatchWithPadding() throws IOException {
-        BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList(BertTokenizer.PAD_TOKEN, BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN,
-                "Elastic", "##search", "fun",
-                "Pancake", "day",
-                "my", "little", "red", "car",
-                "God", "##zilla"
-            ),
-            new BertTokenization(null, null, 512)
-        ).build();
-
-        DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
-        NlpTask.Request request = requestBuilder.buildRequest(
-            List.of("Elasticsearch",
-                "my little red car",
-                "Godzilla day"), "request1");
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
-
-        assertEquals("request1", jsonDocAsMap.get("request_id"));
-        assertThat(jsonDocAsMap.keySet(), hasSize(3));
-        assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
-        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
-
-        assertEquals(Arrays.asList(1, 3, 4, 2, 0, 0), nthListItemFromMap("tokens", 0, jsonDocAsMap));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 0, 0), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
-
-        assertEquals(Arrays.asList(1, 8, 9, 10, 11, 2), nthListItemFromMap("tokens", 1, jsonDocAsMap));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
-
-        assertEquals(Arrays.asList(1, 12, 13, 7, 2, 0), nthListItemFromMap("tokens", 2, jsonDocAsMap));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 0), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
-    }
-}

+ 1 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -11,7 +11,6 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -221,10 +220,7 @@ public class NerProcessorTests extends ESTestCase {
     private static TokenizationResult tokenize(List<String> vocab, String input) {
         BertTokenizer tokenizer = BertTokenizer.builder(
             vocab,
-            randomFrom(
-                new BertTokenization(true, false, null),
-                new DistilBertTokenization(true, false, null)
-            )
+            new BertTokenization(true, false, null)
         ).setDoLowerCase(true).setWithSpecialTokens(false).build();
         return tokenizer.tokenize(List.of(input));
     }

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -60,7 +60,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
                     BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
                 randomAlphaOfLength(10)
             ),
-            new DistilBertTokenization(null, null, 512));
+            new BertTokenization(null, null, 512));
 
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
@@ -69,7 +69,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
-        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
         assertEquals(Arrays.asList(3, 0, 1, 2, 4), ((List<List<Integer>>)jsonDocAsMap.get("tokens")).get(0));
         assertEquals(Arrays.asList(1, 1, 1, 1, 1), ((List<List<Integer>>)jsonDocAsMap.get("arg_1")).get(0));