
[ML] refactoring internal tokenization logic for NLP (#83835)

This simplifies the internal logic used to pass tokenization results around while streamlining how the request sent to the model is built.

This helps lay some of the groundwork for windowing, since collapsing request building and tokenization results will be required (a single sequence could result in a batch request).

Additionally, many of the IntelliJ warnings are addressed and the code is modernized (e.g. by taking advantage of records).
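
To illustrate the kind of modernization the description refers to, here is a minimal, stand-alone Java sketch (hypothetical names, not the actual Elasticsearch classes) of how a record collapses the boilerplate of an explicit value class while keeping its null checks:

    import java.util.Objects;

    public class RecordSketch {

        // Before: an explicit value class with fields, a constructor, and accessors.
        static final class OldRequest {
            private final String tokenization;
            private final byte[] processInput;

            OldRequest(String tokenization, byte[] processInput) {
                this.tokenization = Objects.requireNonNull(tokenization);
                this.processInput = Objects.requireNonNull(processInput);
            }

            String tokenization() { return tokenization; }
            byte[] processInput() { return processInput; }
        }

        // After: a record with a compact constructor for the same validation;
        // the fields, canonical constructor, accessors, equals/hashCode, and
        // toString are generated by the compiler.
        record Request(String tokenization, byte[] processInput) {
            Request {
                Objects.requireNonNull(tokenization);
                Objects.requireNonNull(processInput);
            }
        }

        public static void main(String[] args) {
            Request r = new Request("tokens", new byte[] { 1, 2, 3 });
            System.out.println(r.tokenization()); // generated accessor, no "get" prefix
        }
    }

Call sites then change from field access such as request.processInput to the generated accessor request.processInput(), which is the rename visible in the DeploymentManager diff below.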
Benjamin Trent, 3 years ago
Parent
Commit
ac3d0beaf0
27 changed files with 542 additions and 662 deletions
  1. +12 -7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  2. +0 -71
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  3. +12 -26
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  4. +0 -66
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java
  5. +13 -36
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  6. +17 -62
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  7. +2 -10
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  8. +2 -8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  9. +2 -8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
  10. +8 -40
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  11. +118 -0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java
  12. +26 -101
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  13. +2 -4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/CharSeqTokenTrieNode.java
  14. +78 -0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizationResult.java
  15. +14 -30
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java
  16. +5 -7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  17. +113 -68
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  18. +11 -11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenFilter.java
  19. +2 -6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
  20. +7 -8
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertTokenizationResultTests.java
  21. +15 -9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  22. +7 -8
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetTokenizationResultTests.java
  23. +4 -4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  24. +1 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  25. +1 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java
  26. +61 -61
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java
  27. +9 -9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java

+ 12 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -29,6 +29,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -201,7 +202,11 @@ public class DeploymentManager {
         try (
             InputStream stream = hit.getSourceRef().streamInput();
             XContentParser parser = XContentFactory.xContent(XContentType.JSON)
-                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)
+                .createParser(
+                    XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)
+                        .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE),
+                    stream
+                )
         ) {
             return Vocabulary.createParser(true).apply(parser, null);
         } catch (IOException e) {
@@ -374,8 +379,8 @@ public class DeploymentManager {
                 NlpConfig nlpConfig = (NlpConfig) config;
                 NlpTask.Request request = processor.getRequestBuilder(nlpConfig)
                     .buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate());
-                logger.debug(() -> "Inference Request " + request.processInput.utf8ToString());
-                if (request.tokenization.anyTruncated()) {
+                logger.debug(() -> "Inference Request " + request.processInput().utf8ToString());
+                if (request.tokenization().anyTruncated()) {
                     logger.debug("[{}] [{}] input truncated", modelId, requestId);
                 }
                 processContext.getResultProcessor()
@@ -385,14 +390,14 @@ public class DeploymentManager {
                             inferenceResult -> processResult(
                                 inferenceResult,
                                 processContext,
-                                request.tokenization,
+                                request.tokenization(),
                                 processor.getResultProcessor((NlpConfig) config),
                                 this
                             ),
                             this::onFailure
                         )
                     );
-                processContext.process.get().writeInferenceRequest(request.processInput);
+                processContext.process.get().writeInferenceRequest(request.processInput());
             } catch (IOException e) {
                 logger.error(new ParameterizedMessage("[{}] error writing to inference process", processContext.task.getModelId()), e);
                 onFailure(ExceptionsHelper.serverError("Error writing to inference process", e));
@@ -448,8 +453,8 @@ public class DeploymentManager {
         private volatile Instant startTime;
         private volatile Integer inferenceThreads;
         private volatile Integer modelThreads;
-        private AtomicInteger rejectedExecutionCount = new AtomicInteger();
-        private AtomicInteger timeoutCount = new AtomicInteger();
+        private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
+        private final AtomicInteger timeoutCount = new AtomicInteger();
 
         ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
             this.task = Objects.requireNonNull(task);

+ 0 - 71
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -1,71 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp;
-
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentFactory;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-
-public class BertRequestBuilder implements NlpTask.RequestBuilder {
-
-    static final String REQUEST_ID = "request_id";
-    static final String TOKENS = "tokens";
-    static final String ARG1 = "arg_1";
-    static final String ARG2 = "arg_2";
-    static final String ARG3 = "arg_3";
-
-    private final NlpTokenizer tokenizer;
-
-    public BertRequestBuilder(NlpTokenizer tokenizer) {
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
-        if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
-        }
-
-        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
-            inputs.stream().map(s -> tokenizer.tokenize(s, truncate)).collect(Collectors.toList())
-        );
-        return buildRequest(tokenization, requestId);
-    }
-
-    @Override
-    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
-        if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
-        }
-        return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
-    }
-
-    static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        builder.field(REQUEST_ID, requestId);
-
-        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
-        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
-        int batchSize = tokenization.getTokenizations().size();
-        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG2, batchSize, tokenization.getLongestSequenceLength(), i -> 0, builder);
-        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG3, batchSize, tokenization.getLongestSequenceLength(), i -> i, builder);
-        builder.endObject();
-
-        // BytesReference.bytes closes the builder
-        return BytesReference.bytes(builder);
-    }
-
-}

+ 12 - 26
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -23,20 +23,14 @@ import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResu
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public class FillMaskProcessor implements NlpTask.Processor {
-
-    private final NlpTokenizer tokenizer;
+public class FillMaskProcessor extends NlpTask.Processor {
 
     FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public void close() {
-        tokenizer.close();
+        super(tokenizer);
     }
 
     @Override
@@ -97,7 +91,7 @@ public class FillMaskProcessor implements NlpTask.Processor {
         int numResults,
         String resultsField
     ) {
-        if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
+        if (tokenization.isEmpty()) {
             throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
         }
 
@@ -108,25 +102,20 @@ public class FillMaskProcessor implements NlpTask.Processor {
             );
         }
 
-        int maskTokenIndex = -1;
         int maskTokenId = tokenizer.getMaskTokenId().getAsInt();
-        for (int i = 0; i < tokenization.getTokenizations().get(0).getTokenIds().length; i++) {
-            if (tokenization.getTokenizations().get(0).getTokenIds()[i] == maskTokenId) {
-                maskTokenIndex = i;
-                break;
-            }
-        }
-        if (maskTokenIndex == -1) {
+        OptionalInt maskTokenIndex = tokenization.getTokenization(0).getTokenIndex(maskTokenId);
+        if (maskTokenIndex.isEmpty()) {
             throw new ElasticsearchStatusException(
-                "mask token id [{}] not found in the tokenization {}",
+                "mask token id [{}] not found in the tokenization",
                 RestStatus.INTERNAL_SERVER_ERROR,
-                maskTokenId,
-                List.of(tokenization.getTokenizations().get(0).getTokenIds())
+                maskTokenId
             );
         }
 
         // TODO - process all results in the batch
-        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
+        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(
+            pyTorchResult.getInferenceResult()[0][maskTokenIndex.getAsInt()]
+        );
 
         NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(
             // We need at least one to record the result
@@ -142,10 +131,7 @@ public class FillMaskProcessor implements NlpTask.Processor {
         }
         return new FillMaskResults(
             tokenization.getFromVocab(scoreAndIndices[0].index),
-            tokenization.getTokenizations()
-                .get(0)
-                .getInput()
-                .replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)),
+            tokenization.getTokenization(0).input().replace(tokenizer.getMaskToken(), tokenization.getFromVocab(scoreAndIndices[0].index)),
             results,
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
             scoreAndIndices[0].score,

+ 0 - 66
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilder.java

@@ -1,66 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.inference.nlp;
-
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentFactory;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
-import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-
-public class MPNetRequestBuilder implements NlpTask.RequestBuilder {
-
-    static final String REQUEST_ID = "request_id";
-    static final String TOKENS = "tokens";
-    static final String ARG1 = "arg_1";
-
-    private final NlpTokenizer tokenizer;
-
-    public MPNetRequestBuilder(NlpTokenizer tokenizer) {
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
-        if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
-        }
-
-        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
-            inputs.stream().map(s -> tokenizer.tokenize(s, truncate)).collect(Collectors.toList())
-        );
-        return buildRequest(tokenization, requestId);
-    }
-
-    @Override
-    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
-        if (tokenizer.getPadTokenId().isEmpty()) {
-            throw new IllegalStateException("The input tokenizer does not have a " + tokenizer.getPadToken() + " token in its vocabulary");
-        }
-        return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadTokenId().getAsInt(), requestId));
-    }
-
-    static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        builder.field(REQUEST_ID, requestId);
-
-        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
-        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
-        builder.endObject();
-
-        // BytesReference.bytes closes the builder
-        return BytesReference.bytes(builder);
-    }
-
-}

+ 13 - 36
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -32,7 +32,7 @@ import java.util.Optional;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public class NerProcessor implements NlpTask.Processor {
+public class NerProcessor extends NlpTask.Processor {
 
     public enum Entity implements Writeable {
         NONE,
@@ -83,20 +83,14 @@ public class NerProcessor implements NlpTask.Processor {
     private final IobTag[] iobMap;
     private final String resultsField;
     private final boolean ignoreCase;
-    private final NlpTokenizer tokenizer;
 
     NerProcessor(NlpTokenizer tokenizer, NerConfig config) {
+        super(tokenizer);
         validate(config.getClassificationLabels());
         this.iobMap = buildIobMap(config.getClassificationLabels());
         this.requestBuilder = tokenizer.requestBuilder();
         this.resultsField = config.getResultsField();
         this.ignoreCase = config.getTokenization().doLowerCase();
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public void close() {
-        tokenizer.close();
     }
 
     /**
@@ -188,11 +182,7 @@ public class NerProcessor implements NlpTask.Processor {
         return annotatedResultBuilder.toString();
     }
 
-    static class NerResultProcessor implements NlpTask.ResultProcessor {
-        private final IobTag[] iobMap;
-        private final String resultsField;
-        private final boolean ignoreCase;
-
+    record NerResultProcessor(IobTag[] iobMap, String resultsField, boolean ignoreCase) implements NlpTask.ResultProcessor {
         NerResultProcessor(IobTag[] iobMap, String resultsField, boolean ignoreCase) {
             this.iobMap = iobMap;
             this.resultsField = Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD);
@@ -201,7 +191,7 @@ public class NerProcessor implements NlpTask.Processor {
 
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
-            if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
+            if (tokenization.isEmpty()) {
                 throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
             }
             // TODO - process all results in the batch
@@ -213,18 +203,16 @@ public class NerProcessor implements NlpTask.Processor {
             // of maybe (1 + 0) / 2 = 0.5 while before softmax it'd be exp(10 - 5) / normalization
             // which could easily be close to 1.
             double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
-            List<TaggedToken> taggedTokens = tagTokens(tokenization.getTokenizations().get(0), normalizedScores, iobMap);
+            List<TaggedToken> taggedTokens = tagTokens(tokenization.getTokenization(0), normalizedScores, iobMap);
 
             List<NerResults.EntityGroup> entities = groupTaggedTokens(
                 taggedTokens,
-                ignoreCase
-                    ? tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT)
-                    : tokenization.getTokenizations().get(0).getInput()
+                ignoreCase ? tokenization.getTokenization(0).input().toLowerCase(Locale.ROOT) : tokenization.getTokenization(0).input()
             );
 
             return new NerResults(
                 resultsField,
-                buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities),
+                buildAnnotatedText(tokenization.getTokenization(0).input(), entities),
                 entities,
                 tokenization.anyTruncated()
             );
@@ -236,12 +224,12 @@ public class NerProcessor implements NlpTask.Processor {
          * in the original input replacing them with a single token that
          * gets labelled based on the average score of all its sub-tokens.
          */
-        static List<TaggedToken> tagTokens(TokenizationResult.Tokenization tokenization, double[][] scores, IobTag[] iobMap) {
+        static List<TaggedToken> tagTokens(TokenizationResult.Tokens tokenization, double[][] scores, IobTag[] iobMap) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
             int numSpecialTokens = 0;
-            while (startTokenIndex < tokenization.getTokenIds().length) {
-                int inputMapping = tokenization.getTokenMap()[startTokenIndex];
+            while (startTokenIndex < tokenization.tokenIds().length) {
+                int inputMapping = tokenization.tokenMap()[startTokenIndex];
                 if (inputMapping < 0) {
                     // This token does not map to a token in the input (special tokens)
                     startTokenIndex++;
@@ -249,8 +237,7 @@ public class NerProcessor implements NlpTask.Processor {
                     continue;
                 }
                 int endTokenIndex = startTokenIndex;
-                while (endTokenIndex < tokenization.getTokenMap().length - 1
-                    && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) {
+                while (endTokenIndex < tokenization.tokenMap().length - 1 && tokenization.tokenMap()[endTokenIndex + 1] == inputMapping) {
                     endTokenIndex++;
                 }
                 double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);
@@ -268,7 +255,7 @@ public class NerProcessor implements NlpTask.Processor {
                 int maxScoreIndex = NlpHelpers.argmax(avgScores);
                 double score = avgScores[maxScoreIndex];
                 taggedTokens.add(
-                    new TaggedToken(tokenization.getTokens().get(startTokenIndex - numSpecialTokens), iobMap[maxScoreIndex], score)
+                    new TaggedToken(tokenization.tokens().get(startTokenIndex - numSpecialTokens), iobMap[maxScoreIndex], score)
                 );
                 startTokenIndex = endTokenIndex + 1;
             }
@@ -325,17 +312,7 @@ public class NerProcessor implements NlpTask.Processor {
             return entities;
         }
 
-        static class TaggedToken {
-            private final DelimitedToken token;
-            private final IobTag tag;
-            private final double score;
-
-            TaggedToken(DelimitedToken token, IobTag tag, double score) {
-                this.token = token;
-                this.tag = tag;
-                this.score = score;
-            }
-
+        record TaggedToken(DelimitedToken token, IobTag tag, double score) {
             @Override
             public String toString() {
                 return new StringBuilder("{").append("token:")

+ 17 - 62
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -11,7 +11,6 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
@@ -46,78 +45,37 @@ public class NlpTask {
     }
 
     public interface RequestBuilder {
-        @FunctionalInterface
-        interface IntToIntFunction {
-            int applyAsInt(int value);
-        }
-
-        @FunctionalInterface
-        interface TokenLookupFunction {
-            int apply(TokenizationResult.Tokenization tokenization, int index);
-        }
-
         Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException;
-
-        Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException;
-
-        static void writePaddedTokens(
-            String fieldName,
-            TokenizationResult tokenization,
-            int padToken,
-            TokenLookupFunction generator,
-            XContentBuilder builder
-        ) throws IOException {
-            builder.startArray(fieldName);
-            for (var inputTokens : tokenization.getTokenizations()) {
-                builder.startArray();
-                int i = 0;
-                for (; i < inputTokens.getTokenIds().length; i++) {
-                    builder.value(generator.apply(inputTokens, i));
-                }
-
-                for (; i < tokenization.getLongestSequenceLength(); i++) {
-                    builder.value(padToken);
-                }
-                builder.endArray();
-            }
-            builder.endArray();
-        }
-
-        static void writeNonPaddedArguments(
-            String fieldName,
-            int numTokenizations,
-            int longestSequenceLength,
-            IntToIntFunction generator,
-            XContentBuilder builder
-        ) throws IOException {
-            builder.startArray(fieldName);
-            for (int i = 0; i < numTokenizations; i++) {
-                builder.startArray();
-                for (int j = 0; j < longestSequenceLength; j++) {
-                    builder.value(generator.applyAsInt(j));
-                }
-                builder.endArray();
-            }
-            builder.endArray();
-        }
     }
 
     public interface ResultProcessor {
         InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult);
     }
 
-    public interface Processor extends Releasable {
+    public abstract static class Processor implements Releasable {
+
+        protected final NlpTokenizer tokenizer;
+
+        public Processor(NlpTokenizer tokenizer) {
+            this.tokenizer = tokenizer;
+        }
+
+        @Override
+        public void close() {
+            tokenizer.close();
+        }
+
         /**
          * Validate the task input string.
          * Throws an exception if the inputs fail validation
          *
          * @param inputs Text to validate
          */
-        void validateInputs(List<String> inputs);
+        public abstract void validateInputs(List<String> inputs);
 
-        RequestBuilder getRequestBuilder(NlpConfig config);
+        public abstract RequestBuilder getRequestBuilder(NlpConfig config);
 
-        ResultProcessor getResultProcessor(NlpConfig config);
+        public abstract ResultProcessor getResultProcessor(NlpConfig config);
     }
 
     public static String extractInput(TrainedModelInput input, Map<String, Object> doc) {
@@ -133,10 +91,7 @@ public class NlpTask {
         throw ExceptionsHelper.badRequestException("Input value [{}] for field [{}] must be a string", inputValue, inputField);
     }
 
-    public static class Request {
-        public final TokenizationResult tokenization;
-        public final BytesReference processInput;
-
+    public record Request(TokenizationResult tokenization, BytesReference processInput) {
         public Request(TokenizationResult tokenization, BytesReference processInput) {
             this.tokenization = Objects.requireNonNull(tokenization);
             this.processInput = Objects.requireNonNull(processInput);

+ 2 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -24,21 +24,13 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceCo
  * A NLP processor that directly returns the PyTorch result
  * without any post-processing
  */
-public class PassThroughProcessor implements NlpTask.Processor {
+public class PassThroughProcessor extends NlpTask.Processor {
 
     private final NlpTask.RequestBuilder requestBuilder;
-    private final NlpTokenizer tokenizer;
-    private final String resultsField;
 
     PassThroughProcessor(NlpTokenizer tokenizer, PassThroughConfig config) {
+        super(tokenizer);
         this.requestBuilder = tokenizer.requestBuilder();
-        this.resultsField = config.getResultsField();
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public void close() {
-        tokenizer.close();
     }
 
     @Override

+ 2 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -27,26 +27,20 @@ import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public class TextClassificationProcessor implements NlpTask.Processor {
+public class TextClassificationProcessor extends NlpTask.Processor {
 
     private final NlpTask.RequestBuilder requestBuilder;
-    private final NlpTokenizer tokenizer;
     private final String[] classLabels;
     private final int numTopClasses;
 
     TextClassificationProcessor(NlpTokenizer tokenizer, TextClassificationConfig config) {
+        super(tokenizer);
         this.requestBuilder = tokenizer.requestBuilder();
         List<String> classLabels = config.getClassificationLabels();
         this.classLabels = classLabels.toArray(String[]::new);
         // negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size
         // This is a nice way of setting the value
         this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public void close() {
-        tokenizer.close();
     }
 
     @Override

+ 2 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -23,19 +23,13 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceCo
 /**
  * A NLP processor that returns a single double[] output from the model. Assumes that only one tensor is returned via inference
  **/
-public class TextEmbeddingProcessor implements NlpTask.Processor {
+public class TextEmbeddingProcessor extends NlpTask.Processor {
 
     private final NlpTask.RequestBuilder requestBuilder;
-    private final NlpTokenizer tokenizer;
 
     TextEmbeddingProcessor(NlpTokenizer tokenizer, TextEmbeddingConfig config) {
+        super(tokenizer);
         this.requestBuilder = tokenizer.requestBuilder();
-        this.tokenizer = tokenizer;
-    }
-
-    @Override
-    public void close() {
-        tokenizer.close();
     }
 
     @Override

+ 8 - 40
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -33,9 +33,8 @@ import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 
-public class ZeroShotClassificationProcessor implements NlpTask.Processor {
+public class ZeroShotClassificationProcessor extends NlpTask.Processor {
 
-    private final NlpTokenizer tokenizer;
     private final int entailmentPos;
     private final int contraPos;
     private final String[] labels;
@@ -44,7 +43,7 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
     private final String resultsField;
 
     ZeroShotClassificationProcessor(NlpTokenizer tokenizer, ZeroShotClassificationConfig config) {
-        this.tokenizer = tokenizer;
+        super(tokenizer);
         List<String> lowerCased = config.getClassificationLabels()
             .stream()
             .map(s -> s.toLowerCase(Locale.ROOT))
@@ -62,11 +61,6 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
         this.resultsField = config.getResultsField();
     }
 
-    @Override
-    public void close() {
-        tokenizer.close();
-    }
-
     @Override
     public void validateInputs(List<String> inputs) {
         // nothing to validate
@@ -103,51 +97,25 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
         return new ResultProcessor(entailmentPos, contraPos, labelsValue, isMultiLabelValue, resultsFieldValue);
     }
 
-    static class RequestBuilder implements NlpTask.RequestBuilder {
-
-        private final NlpTokenizer tokenizer;
-        private final String[] labels;
-        private final String hypothesisTemplate;
-
-        RequestBuilder(NlpTokenizer tokenizer, String[] labels, String hypothesisTemplate) {
-            this.tokenizer = tokenizer;
-            this.labels = labels;
-            this.hypothesisTemplate = hypothesisTemplate;
-        }
+    record RequestBuilder(NlpTokenizer tokenizer, String[] labels, String hypothesisTemplate) implements NlpTask.RequestBuilder {
 
         @Override
         public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
             if (inputs.size() > 1) {
                 throw ExceptionsHelper.badRequestException("Unable to do zero-shot classification on more than one text input at a time");
             }
-            List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
+            List<TokenizationResult.Tokens> tokenizations = new ArrayList<>(labels.length);
             for (String label : labels) {
                 tokenizations.add(tokenizer.tokenize(inputs.get(0), LoggerMessageFormat.format(null, hypothesisTemplate, label), truncate));
             }
             TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations);
-            return buildRequest(result, requestId);
-        }
-
-        @Override
-        public NlpTask.Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException {
-            return tokenizer.requestBuilder().buildRequest(tokenizationResult, requestId);
+            return result.buildRequest(requestId, truncate);
         }
     }
 
-    static class ResultProcessor implements NlpTask.ResultProcessor {
-        private final int entailmentPos;
-        private final int contraPos;
-        private final String[] labels;
-        private final boolean isMultiLabel;
-        private final String resultsField;
-
-        ResultProcessor(int entailmentPos, int contraPos, String[] labels, boolean isMultiLabel, String resultsField) {
-            this.entailmentPos = entailmentPos;
-            this.contraPos = contraPos;
-            this.labels = labels;
-            this.isMultiLabel = isMultiLabel;
-            this.resultsField = resultsField;
-        }
+    record ResultProcessor(int entailmentPos, int contraPos, String[] labels, boolean isMultiLabel, String resultsField)
+        implements
+            NlpTask.ResultProcessor {
 
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {

+ 118 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizationResult.java

@@ -0,0 +1,118 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+public class BertTokenizationResult extends TokenizationResult {
+
+    static final String REQUEST_ID = "request_id";
+    static final String TOKENS = "tokens";
+    static final String ARG1 = "arg_1";
+    static final String ARG2 = "arg_2";
+    static final String ARG3 = "arg_3";
+
+    public BertTokenizationResult(List<String> vocab, List<TokenizationResult.Tokens> tokenizations, int padTokenId) {
+        super(vocab, tokenizations, padTokenId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(REQUEST_ID, requestId);
+        writePaddedTokens(TOKENS, builder);
+        writeAttentionMask(ARG1, builder);
+        writeTokenTypeIds(ARG2, builder);
+        writePositionIds(ARG3, builder);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        BytesReference jsonRequest = BytesReference.bytes(builder);
+        return new NlpTask.Request(this, jsonRequest);
+    }
+
+    static class BertTokensBuilder implements TokensBuilder {
+        protected final Stream.Builder<IntStream> tokenIds;
+        protected final Stream.Builder<IntStream> tokenMap;
+        protected final boolean withSpecialTokens;
+        protected final int clsTokenId;
+        protected final int sepTokenId;
+
+        BertTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) {
+            this.withSpecialTokens = withSpecialTokens;
+            this.clsTokenId = clsTokenId;
+            this.sepTokenId = sepTokenId;
+            this.tokenIds = Stream.builder();
+            this.tokenMap = Stream.builder();
+        }
+
+        @Override
+        public TokensBuilder addSequence(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+
+        @Override
+        public TokensBuilder addSequencePair(
+            List<Integer> tokenId1s,
+            List<Integer> tokenMap1,
+            List<Integer> tokenId2s,
+            List<Integer> tokenMap2
+        ) {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(tokenId1s.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenMap1.stream().mapToInt(Integer::valueOf));
+            int previouslyFinalMap = tokenMap1.get(tokenMap1.size() - 1);
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(tokenId2s.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenMap2.stream().mapToInt(i -> i + previouslyFinalMap));
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+
+        @Override
+        public Tokens build(String input, boolean truncated, List<? extends DelimitedToken> allTokens) {
+            return new Tokens(
+                input,
+                allTokens,
+                truncated,
+                tokenIds.build().flatMapToInt(Function.identity()).toArray(),
+                tokenMap.build().flatMapToInt(Function.identity()).toArray()
+            );
+        }
+    }
+}

+ 26 - 101
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -11,7 +11,6 @@ import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 
 import java.io.IOException;
@@ -23,10 +22,7 @@ import java.util.OptionalInt;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
-import java.util.function.Function;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
 
 /**
  * Performs basic tokenization and normalization of input text
@@ -49,17 +45,17 @@ public class BertTokenizer implements NlpTokenizer {
     private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
 
     private final WordPieceAnalyzer wordPieceAnalyzer;
-    private final List<String> originalVocab;
+    protected final List<String> originalVocab;
     // TODO Not sure this needs to be a sorted map
     private final SortedMap<String, Integer> vocab;
     protected final boolean withSpecialTokens;
     private final int maxSequenceLength;
-    private final NlpTask.RequestBuilder requestBuilder;
     private final String sepToken;
     protected final int sepTokenId;
     private final String clsToken;
     private final int clsTokenId;
     private final String padToken;
+    protected final int padTokenId;
     private final String maskToken;
     private final String unknownToken;
 
@@ -71,7 +67,6 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
     ) {
         this(
@@ -82,7 +77,6 @@ public class BertTokenizer implements NlpTokenizer {
             doStripAccents,
             withSpecialTokens,
             maxSequenceLength,
-            requestBuilderFactory,
             Sets.union(neverSplit, NEVER_SPLIT),
             SEPARATOR_TOKEN,
             CLASS_TOKEN,
@@ -100,7 +94,6 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit,
         String sepToken,
         String clsToken,
@@ -120,13 +113,13 @@ public class BertTokenizer implements NlpTokenizer {
         this.vocab = vocab;
         this.withSpecialTokens = withSpecialTokens;
         this.maxSequenceLength = maxSequenceLength;
-        this.requestBuilder = requestBuilderFactory.apply(this);
         if (vocab.containsKey(unknownToken) == false) {
             throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", unknownToken);
         }
         if (vocab.containsKey(padToken) == false) {
             throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", padToken);
         }
+        this.padTokenId = vocab.get(padToken);
 
         if (withSpecialTokens) {
             Set<String> missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
@@ -188,12 +181,12 @@ public class BertTokenizer implements NlpTokenizer {
     }
 
     @Override
-    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations) {
-        TokenizationResult tokenizationResult = new TokenizationResult(originalVocab);
-        for (TokenizationResult.Tokenization tokenization : tokenizations) {
-            tokenizationResult.addTokenization(tokenization);
-        }
-        return tokenizationResult;
+    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
+        return new BertTokenizationResult(originalVocab, tokenizations, vocab.get(this.padToken));
+    }
+
+    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
+        return new BertTokenizationResult.BertTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
     }
 
     /**
@@ -208,7 +201,7 @@ public class BertTokenizer implements NlpTokenizer {
      * @return A {@link Tokenization}
      */
     @Override
-    public TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncate truncate) {
+    public TokenizationResult.Tokens tokenize(String seq, Tokenization.Truncate truncate) {
         var innerResult = innerTokenize(seq);
         List<WordPieceTokenFilter.WordPieceToken> wordPieceTokenIds = innerResult.tokens;
         List<Integer> tokenPositionMap = innerResult.tokenPositionMap;
@@ -229,21 +222,14 @@ public class BertTokenizer implements NlpTokenizer {
                 );
             }
         }
-        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(
+        return createTokensBuilder(clsTokenId, sepTokenId, withSpecialTokens).addSequence(
             wordPieceTokenIds.stream().map(WordPieceTokenFilter.WordPieceToken::getEncoding).collect(Collectors.toList()),
             tokenPositionMap
-        ).addEndTokensIfNecessary();
-        return new TokenizationResult.Tokenization(
-            seq,
-            innerResult.tokens,
-            isTruncated,
-            bertTokenizationBuilder.buildIds(),
-            bertTokenizationBuilder.buildMap()
-        );
+        ).build(seq, isTruncated, innerResult.tokens);
     }
 
     @Override
-    public TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate) {
+    public TokenizationResult.Tokens tokenize(String seq1, String seq2, Tokenization.Truncate truncate) {
         var innerResultSeq1 = innerTokenize(seq1);
         List<WordPieceTokenFilter.WordPieceToken> wordPieceTokenIdsSeq1 = innerResultSeq1.tokens;
         List<Integer> tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap;
@@ -302,28 +288,21 @@ public class BertTokenizer implements NlpTokenizer {
                 );
             }
         }
-        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(
-            wordPieceTokenIdsSeq1.stream().map(WordPieceTokenFilter.WordPieceToken::getEncoding).collect(Collectors.toList()),
-            tokenPositionMapSeq1
-        )
-            .addTokens(
-                wordPieceTokenIdsSeq2.stream().map(WordPieceTokenFilter.WordPieceToken::getEncoding).collect(Collectors.toList()),
-                tokenPositionMapSeq2
-            )
-            .addEndTokensIfNecessary();
         List<WordPieceTokenFilter.WordPieceToken> tokens = new ArrayList<>(innerResultSeq1.tokens);
         tokens.addAll(innerResultSeq2.tokens);
-        return new TokenizationResult.Tokenization(
-            seq1 + seq2,
-            tokens,
-            isTruncated,
-            bertTokenizationBuilder.buildIds(),
-            bertTokenizationBuilder.buildMap()
-        );
+        return createTokensBuilder(clsTokenId, sepTokenId, withSpecialTokens).addSequencePair(
+            wordPieceTokenIdsSeq1.stream().map(WordPieceTokenFilter.WordPieceToken::getEncoding).collect(Collectors.toList()),
+            tokenPositionMapSeq1,
+            wordPieceTokenIdsSeq2.stream().map(WordPieceTokenFilter.WordPieceToken::getEncoding).collect(Collectors.toList()),
+            tokenPositionMapSeq2
+        ).build(seq1 + seq2, isTruncated, tokens);
     }
 
-    protected BertTokenizationBuilder bertTokenizationBuilder() {
-        return new BertTokenizationBuilder();
+    @Override
+    public NlpTask.RequestBuilder requestBuilder() {
+        return (inputs, requestId, truncate) -> buildTokenizationResult(
+            inputs.stream().map(s -> tokenize(s, truncate)).collect(Collectors.toList())
+        ).buildRequest(requestId, truncate);
     }
 
     protected int getNumExtraTokensForSeqPair() {
@@ -361,11 +340,6 @@ public class BertTokenizer implements NlpTokenizer {
         }
     }
 
-    @Override
-    public NlpTask.RequestBuilder requestBuilder() {
-        return requestBuilder;
-    }
-
     public int getMaxSequenceLength() {
         return maxSequenceLength;
     }
@@ -374,59 +348,16 @@ public class BertTokenizer implements NlpTokenizer {
         return new Builder(vocab, tokenization);
     }
 
-    protected class BertTokenizationBuilder {
-        Stream.Builder<IntStream> tokenIds;
-        Stream.Builder<IntStream> tokenMap;
-        int numSeq;
-
-        BertTokenizationBuilder() {
-            tokenIds = Stream.builder();
-            tokenMap = Stream.builder();
-            if (withSpecialTokens) {
-                tokenIds.add(IntStream.of(clsTokenId));
-                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
-            }
-        }
-
-        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
-            if (numSeq > 0 && withSpecialTokens) {
-                tokenIds.add(IntStream.of(sepTokenId));
-                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
-            }
-            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
-            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
-            numSeq++;
-            return this;
-        }
-
-        BertTokenizationBuilder addEndTokensIfNecessary() {
-            if (withSpecialTokens) {
-                tokenIds.add(IntStream.of(sepTokenId));
-                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
-            }
-            return this;
-        }
-
-        int[] buildIds() {
-            return tokenIds.build().flatMapToInt(Function.identity()).toArray();
-        }
-
-        int[] buildMap() {
-            return tokenMap.build().flatMapToInt(Function.identity()).toArray();
-        }
-    }
-
     public static class Builder {
 
         protected final List<String> originalVocab;
         protected final SortedMap<String, Integer> vocab;
-        protected boolean doLowerCase = false;
+        protected boolean doLowerCase;
         protected boolean doTokenizeCjKChars = true;
-        protected boolean withSpecialTokens = true;
+        protected boolean withSpecialTokens;
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
-        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
 
         protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
@@ -479,11 +410,6 @@ public class BertTokenizer implements NlpTokenizer {
             return this;
         }
 
-        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
-            this.requestBuilderFactory = requestBuilderFactory;
-            return this;
-        }
-
         public BertTokenizer build() {
             // if not set strip accents defaults to the value of doLowerCase
             if (doStripAccents == null) {
@@ -502,7 +428,6 @@ public class BertTokenizer implements NlpTokenizer {
                 doStripAccents,
                 withSpecialTokens,
                 maxSequenceLength,
-                requestBuilderFactory,
                 neverSplit
             );
         }

+ 2 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/CharSeqTokenTrieNode.java

@@ -16,13 +16,11 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 
-public class CharSeqTokenTrieNode {
+public record CharSeqTokenTrieNode(CharArrayMap<CharSeqTokenTrieNode> children) {
 
     public static final CharSeqTokenTrieNode EMPTY = new CharSeqTokenTrieNode(new CharArrayMap<>(0, false));
 
-    private final CharArrayMap<CharSeqTokenTrieNode> children;
-
-    private CharSeqTokenTrieNode(CharArrayMap<CharSeqTokenTrieNode> children) {
+    public CharSeqTokenTrieNode(CharArrayMap<CharSeqTokenTrieNode> children) {
         this.children = Objects.requireNonNull(children);
     }
 

+ 78 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizationResult.java

@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.IntStream;
+
+public class MPNetTokenizationResult extends TokenizationResult {
+
+    static final String REQUEST_ID = "request_id";
+    static final String TOKENS = "tokens";
+    static final String ARG1 = "arg_1";
+
+    public MPNetTokenizationResult(List<String> vocab, List<Tokens> tokenizations, int padTokenId) {
+        super(vocab, tokenizations, padTokenId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(REQUEST_ID, requestId);
+        writePaddedTokens(TOKENS, builder);
+        writeAttentionMask(ARG1, builder);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        BytesReference jsonRequest = BytesReference.bytes(builder);
+        return new NlpTask.Request(this, jsonRequest);
+    }
+
+    static class MPNetTokensBuilder extends BertTokenizationResult.BertTokensBuilder {
+
+        MPNetTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) {
+            super(withSpecialTokens, clsTokenId, sepTokenId);
+        }
+
+        @Override
+        public TokensBuilder addSequencePair(
+            List<Integer> tokenId1s,
+            List<Integer> tokenMap1,
+            List<Integer> tokenId2s,
+            List<Integer> tokenMap2
+        ) {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(tokenId1s.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenMap1.stream().mapToInt(Integer::valueOf));
+            int previouslyFinalMap = tokenMap1.get(tokenMap1.size() - 1);
+            // MPNet adds two `</s>` between sequence pairs
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId, sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION, SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(tokenId2s.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenMap2.stream().mapToInt(i -> i + previouslyFinalMap));
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+    }
+}

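For context, a minimal sketch of the sequence-pair layout MPNetTokensBuilder produces, assuming an MPNetTokenizer built with special tokens enabled (see MPNetTokenizerTests below); vocabulary ids are illustrative.

    // `tokenizer` is an assumed MPNetTokenizer instance with special tokens enabled
    TokenizationResult.Tokens pair = tokenizer.tokenize(
        "Elasticsearch is fun",
        "Godzilla my little red car",
        Tokenization.Truncate.NONE
    );
    // pair.tokenIds() lays out as: <s>, seq1 ids..., </s>, </s>, seq2 ids..., </s>
    // i.e. MPNet separates a sequence pair with a double </s>, where BERT uses a single [SEP]
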
+ 14 - 30
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizer.java

@@ -8,7 +8,6 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
-import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 
 import java.util.Collections;
@@ -16,8 +15,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
-import java.util.function.Function;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;
 
 /**
  * Performs basic tokenization and normalization of input text
@@ -41,7 +39,6 @@ public class MPNetTokenizer extends BertTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
     ) {
         super(
@@ -52,7 +49,6 @@ public class MPNetTokenizer extends BertTokenizer {
             doStripAccents,
             withSpecialTokens,
             maxSequenceLength,
-            requestBuilderFactory,
             Sets.union(neverSplit, NEVER_SPLIT),
             SEPARATOR_TOKEN,
             CLASS_TOKEN,
@@ -67,25 +63,20 @@ public class MPNetTokenizer extends BertTokenizer {
         return 4;
     }
 
-    @Override
-    protected BertTokenizationBuilder bertTokenizationBuilder() {
-        return new MPNetTokenizationBuilder();
+    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
+        return new MPNetTokenizationResult.MPNetTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
     }
 
-    protected class MPNetTokenizationBuilder extends BertTokenizationBuilder {
-
-        @Override
-        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
-            if (numSeq > 0 && withSpecialTokens) {
-                tokenIds.add(IntStream.of(sepTokenId, sepTokenId));
-                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION, SPECIAL_TOKEN_POSITION));
-            }
-            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
-            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
-            numSeq++;
-            return this;
-        }
+    @Override
+    public NlpTask.RequestBuilder requestBuilder() {
+        return (inputs, requestId, truncate) -> buildTokenizationResult(
+            inputs.stream().map(s -> tokenize(s, truncate)).collect(Collectors.toList())
+        ).buildRequest(requestId, truncate);
+    }
 
+    @Override
+    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
+        return new MPNetTokenizationResult(originalVocab, tokenizations, getPadTokenId().orElseThrow());
     }
 
     public static Builder mpBuilder(List<String> vocab, Tokenization tokenization) {
@@ -96,13 +87,12 @@ public class MPNetTokenizer extends BertTokenizer {
 
         protected final List<String> originalVocab;
         protected final SortedMap<String, Integer> vocab;
-        protected boolean doLowerCase = false;
+        protected boolean doLowerCase;
         protected boolean doTokenizeCjKChars = true;
-        protected boolean withSpecialTokens = true;
+        protected boolean withSpecialTokens;
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
-        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = MPNetRequestBuilder::new;
 
         protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
@@ -155,11 +145,6 @@ public class MPNetTokenizer extends BertTokenizer {
             return this;
         }
 
-        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
-            this.requestBuilderFactory = requestBuilderFactory;
-            return this;
-        }
-
         public MPNetTokenizer build() {
             // if not set strip accents defaults to the value of doLowerCase
             if (doStripAccents == null) {
@@ -178,7 +163,6 @@ public class MPNetTokenizer extends BertTokenizer {
                 doStripAccents,
                 withSpecialTokens,
                 maxSequenceLength,
-                requestBuilderFactory,
                 neverSplit
             );
         }

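With request building folded into the tokenizer, an MPNet inference request can be built roughly as in the sketch below, mirroring MPNetTokenizationResultTests later in this diff; the `vocab` list is assumed.

    try (MPNetTokenizer tokenizer = MPNetTokenizer.mpBuilder(vocab, new MPNetTokenization(null, null, 512, null)).build()) {
        NlpTask.Request request = tokenizer.requestBuilder()
            .buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
        // request.processInput() is a JSON body with three fields:
        //   "request_id": "request1", "tokens": [[...]], "arg_1": [[1, 1, ...]]
    }
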
+ 5 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -12,8 +12,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
-import org.elasticsearch.xpack.ml.inference.nlp.MPNetRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
 
@@ -25,11 +23,11 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.V
 
 public interface NlpTokenizer extends Releasable {
 
-    TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations);
+    TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations);
 
-    TokenizationResult.Tokenization tokenize(String seq, Tokenization.Truncate truncate);
+    TokenizationResult.Tokens tokenize(String seq, Tokenization.Truncate truncate);
 
-    TokenizationResult.Tokenization tokenize(String seq1, String seq2, Tokenization.Truncate truncate);
+    TokenizationResult.Tokens tokenize(String seq1, String seq2, Tokenization.Truncate truncate);
 
     NlpTask.RequestBuilder requestBuilder();
 
@@ -45,10 +43,10 @@ public interface NlpTokenizer extends Releasable {
         ExceptionsHelper.requireNonNull(params, TOKENIZATION);
         ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);
         if (params instanceof BertTokenization) {
-            return BertTokenizer.builder(vocabulary.get(), params).setRequestBuilderFactory(BertRequestBuilder::new).build();
+            return BertTokenizer.builder(vocabulary.get(), params).build();
         }
         if (params instanceof MPNetTokenization) {
-            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).setRequestBuilderFactory(MPNetRequestBuilder::new).build();
+            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).build();
         }
         throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
     }

+ 113 - 68
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -7,105 +7,150 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
-import java.util.ArrayList;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
+
+import java.io.IOException;
 import java.util.List;
+import java.util.OptionalInt;
+import java.util.stream.IntStream;
 
-public class TokenizationResult {
+public abstract class TokenizationResult {
+    public static final int SPECIAL_TOKEN_POSITION = -1;
 
     private final List<String> vocab;
-    private final List<Tokenization> tokenizations = new ArrayList<>();
-    private int maxLength;
+    private final List<Tokens> tokens;
+    private final int maxLength;
+    private final int padTokenId;
 
-    public TokenizationResult(List<String> vocab) {
+    protected TokenizationResult(List<String> vocab, List<Tokens> tokenizations, int padTokenId) {
         this.vocab = vocab;
-        this.maxLength = -1;
+        this.tokens = tokenizations;
+        this.padTokenId = padTokenId;
+        int max = 0;
+        for (Tokens tokenization : tokenizations) {
+            max = Math.max(tokenization.tokenIds.length, max);
+        }
+        this.maxLength = max;
     }
 
-    public boolean anyTruncated() {
-        return tokenizations.stream().anyMatch(Tokenization::isTruncated);
+    List<Tokens> getTokens() {
+        return tokens;
     }
 
     public String getFromVocab(int tokenId) {
         return vocab.get(tokenId);
     }
 
-    public List<Tokenization> getTokenizations() {
-        return tokenizations;
+    public Tokens getTokenization(int tokenizationIndex) {
+        return tokens.get(tokenizationIndex);
     }
 
-    public void addTokenization(
-        String input,
-        boolean isTruncated,
-        List<WordPieceTokenFilter.WordPieceToken> tokens,
-        int[] tokenIds,
-        int[] tokenMap
-    ) {
-        maxLength = Math.max(maxLength, tokenIds.length);
-        tokenizations.add(new Tokenization(input, tokens, isTruncated, tokenIds, tokenMap));
+    public boolean anyTruncated() {
+        return tokens.stream().anyMatch(Tokens::truncated);
     }
 
-    public void addTokenization(Tokenization tokenization) {
-        maxLength = Math.max(maxLength, tokenization.tokenIds.length);
-        tokenizations.add(tokenization);
+    public boolean isEmpty() {
+        return this.tokens.isEmpty() || this.tokens.stream().allMatch(t -> t.tokenIds.length == 0);
     }
 
-    public int getLongestSequenceLength() {
-        return maxLength;
+    public abstract NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException;
+
+    protected void writePaddedTokens(String fieldName, XContentBuilder builder) throws IOException {
+        builder.startArray(fieldName);
+        for (var inputTokens : tokens) {
+            builder.startArray();
+
+            // Note, cannot write the array directly as the internal builder code writes start/end array values
+            for (int t : inputTokens.tokenIds) {
+                builder.value(t);
+            }
+            for (int i = inputTokens.tokenIds.length; i < maxLength; i++) {
+                builder.value(padTokenId);
+            }
+            builder.endArray();
+        }
+        builder.endArray();
     }
 
-    public static class Tokenization {
-
-        private final String input;
-        private final List<WordPieceTokenFilter.WordPieceToken> tokens;
-        private final int[] tokenIds;
-        private final int[] tokenMap;
-        private final boolean truncated;
-
-        public Tokenization(
-            String input,
-            List<WordPieceTokenFilter.WordPieceToken> tokens,
-            boolean truncated,
-            int[] tokenIds,
-            int[] tokenMap
-        ) {
-            assert tokenIds.length == tokenMap.length;
-            this.input = input;
-            this.tokens = tokens;
-            this.tokenIds = tokenIds;
-            this.tokenMap = tokenMap;
-            this.truncated = truncated;
+    protected void writeAttentionMask(String fieldName, XContentBuilder builder) throws IOException {
+        builder.startArray(fieldName);
+        for (var inputTokens : tokens) {
+            builder.startArray();
+            // Note, cannot write the array directly as the internal builder code writes start/end array values
+            for (int ignored : inputTokens.tokenIds) {
+                builder.value(1);
+            }
+            for (int i = inputTokens.tokenIds.length; i < maxLength; i++) {
+                builder.value(padTokenId);
+            }
+            builder.endArray();
         }
+        builder.endArray();
+    }
 
-        /**
-         * The integer values of the tokens}
-         *
-         * @return A list of token Ids
-         */
-        public int[] getTokenIds() {
-            return tokenIds;
+    protected void writeTokenTypeIds(String fieldName, XContentBuilder builder) throws IOException {
+        builder.startArray(fieldName);
+        for (int i = 0; i < tokens.size(); i++) {
+            builder.startArray();
+            for (int j = 0; j < maxLength; j++) {
+                builder.value(0);
+            }
+            builder.endArray();
         }
+        builder.endArray();
+    }
 
-        /**
-         * Maps the token position to the position in the source text.
-         * Source words may be divided into more than one token so more
-         * than one token can map back to the source token
-         *
-         * @return Map of source token to
-         */
-        public int[] getTokenMap() {
-            return tokenMap;
+    protected void writePositionIds(String fieldName, XContentBuilder builder) throws IOException {
+        builder.startArray(fieldName);
+        for (int i = 0; i < tokens.size(); i++) {
+            builder.startArray();
+            for (int j = 0; j < maxLength; j++) {
+                builder.value(j);
+            }
+            builder.endArray();
         }
+        builder.endArray();
+    }
 
-        public String getInput() {
-            return input;
-        }
+    public record Tokens(String input, List<? extends DelimitedToken> tokens, boolean truncated, int[] tokenIds, int[] tokenMap) {
 
-        public List<WordPieceTokenFilter.WordPieceToken> getTokens() {
-            return tokens;
+        public Tokens {
+            assert tokenIds.length == tokenMap.length;
         }
 
-        public boolean isTruncated() {
-            return truncated;
+        public OptionalInt getTokenIndex(int token) {
+            return IntStream.range(0, tokenIds.length).filter(tokenIndex -> token == tokenIds[tokenIndex]).findFirst();
         }
     }
+
+    interface TokensBuilder {
+        /**
+         * Adds tokens to the token builder
+         * @param tokenIds Token ids without special tokens added
+         * @param tokenMap Token map without considering special tokens
+         * @return The builder object
+         */
+        TokensBuilder addSequence(List<Integer> tokenIds, List<Integer> tokenMap);
+
+        /**
+         * Adds an encoded sequence pair to the token builder
+         * @param tokenId1s Sequence 1 ids
+         * @param tokenMap1 Sequence 1 token mappings
+         * @param tokenId2s Sequence 2 ids
+         * @param tokenMap2 Sequence 2 token map
+         * @return The builder object
+         */
+        TokensBuilder addSequencePair(List<Integer> tokenId1s, List<Integer> tokenMap1, List<Integer> tokenId2s, List<Integer> tokenMap2);
+
+        /**
+         * Builds the token object
+         * @param input the original sequence input, may be a simple concatenation of a sequence pair
+         * @param truncated Was this truncated when tokenized
+         * @param allTokens All the tokens with their values and offsets
+         * @return A new Tokens object
+         */
+        Tokens build(String input, boolean truncated, List<? extends DelimitedToken> allTokens);
+    }
 }

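A small sketch of how the new result type pads a batch, using a hypothetical vocabulary and token ids (construction mirrors FillMaskProcessorTests below): every tokenization is right-padded with padTokenId up to the longest sequence in the batch.

    TokenizationResult result = new BertTokenizationResult(
        List.of("Godzilla", "day", "[PAD]"),        // illustrative vocab
        List.of(
            new TokenizationResult.Tokens("Godzilla", List.of(), false, new int[] { 0 }, new int[] { 0 }),
            new TokenizationResult.Tokens("Godzilla day", List.of(), false, new int[] { 0, 1 }, new int[] { 0, 1 })
        ),
        2                                           // pad token id, i.e. "[PAD]"
    );
    NlpTask.Request request = result.buildRequest("request1", Tokenization.Truncate.NONE);
    // "tokens" in request.processInput() becomes [[0, 2], [0, 1]]: the shorter
    // tokenization is padded with the pad token id to the batch maxLength of 2.
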
+ 11 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenFilter.java

@@ -19,11 +19,12 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Objects;
 
 public final class WordPieceTokenFilter extends TokenFilter {
-    protected final LinkedList<WordPieceToken> tokens;
+    private final LinkedList<WordPieceToken> tokens;
     private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
-    protected final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
+    private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
     private final PositionIncrementAttribute posIncAtt = addAttribute(PositionIncrementAttribute.class);
     private static final CharSequence CONTINUATION = "##";
 
@@ -105,15 +106,14 @@ public final class WordPieceTokenFilter extends TokenFilter {
         if (input.incrementToken()) {
             if (neverSplit.contains(termAtt)) {
                 Integer maybeTokenized = vocabulary.get(termAtt);
-                if (maybeTokenized == null) {
-                    tokenizedValues.add(
-                        new WordPieceToken(termAtt.toString(), tokenizedUnknown, offsetAtt.startOffset(), offsetAtt.endOffset())
-                    );
-                } else {
-                    tokenizedValues.add(
-                        new WordPieceToken(termAtt.toString(), maybeTokenized, offsetAtt.startOffset(), offsetAtt.endOffset())
-                    );
-                }
+                tokenizedValues.add(
+                    new WordPieceToken(
+                        termAtt.toString(),
+                        Objects.requireNonNullElse(maybeTokenized, tokenizedUnknown),
+                        offsetAtt.startOffset(),
+                        offsetAtt.endOffset()
+                    )
+                );
                 return true;
             }
             if (termAtt.length() > maxInputCharsPerWord) {

+ 2 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java

@@ -91,14 +91,10 @@ public class PyTorchInferenceResult implements ToXContentObject {
         builder.field(REQUEST_ID.getPreferredName(), requestId);
         if (inference != null) {
             builder.startArray(INFERENCE.getPreferredName());
-            for (int i = 0; i < inference.length; i++) {
+            for (double[][] doubles : inference) {
                 builder.startArray();
                 for (int j = 0; j < inference[0].length; j++) {
-                    builder.startArray();
-                    for (int k = 0; k < inference[0][0].length; k++) {
-                        builder.value(inference[i][j][k]);
-                    }
-                    builder.endArray();
+                    builder.value(doubles[j]);
                 }
                 builder.endArray();
             }

+ 7 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertTokenizationResultTests.java

@@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerT
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 
-public class BertRequestBuilderTests extends ESTestCase {
+public class BertTokenizationResultTests extends ESTestCase {
 
     private BertTokenizer tokenizer;
 
@@ -40,9 +40,9 @@ public class BertRequestBuilderTests extends ESTestCase {
     public void testBuildRequest() throws IOException {
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, null, 512, null)).build();
 
-        BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
+        var requestBuilder = tokenizer.requestBuilder();
         NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
@@ -52,7 +52,6 @@ public class BertRequestBuilderTests extends ESTestCase {
         assertEquals(Arrays.asList(0, 1, 2, 3, 4), firstListItemFromMap("arg_3", jsonDocAsMap));
     }
 
-    @SuppressWarnings("unchecked")
     private List<Integer> firstListItemFromMap(String name, Map<String, Object> jsonDocAsMap) {
         return nthListItemFromMap(name, 0, jsonDocAsMap);
     }
@@ -65,7 +64,7 @@ public class BertRequestBuilderTests extends ESTestCase {
     public void testInputTooLarge() throws IOException {
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, null, 5, null)).build();
         {
-            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
+            var requestBuilder = tokenizer.requestBuilder();
             ElasticsearchStatusException e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> requestBuilder.buildRequest(
@@ -81,7 +80,7 @@ public class BertRequestBuilderTests extends ESTestCase {
             );
         }
         {
-            BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
+            var requestBuilder = tokenizer.requestBuilder();
             // input will become 3 tokens + the Class and Separator token = 5 which is
             // our max sequence length
             requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
@@ -92,13 +91,13 @@ public class BertRequestBuilderTests extends ESTestCase {
     public void testBatchWithPadding() throws IOException {
         tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, new BertTokenization(null, null, 512, null)).build();
 
-        BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
+        var requestBuilder = tokenizer.requestBuilder();
         NlpTask.Request request = requestBuilder.buildRequest(
             List.of("Elasticsearch", "my little red car", "Godzilla day"),
             "request1",
             Tokenization.Truncate.NONE
         );
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));

+ 15 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -14,13 +14,13 @@ import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.WordPieceTokenFilter;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.OptionalInt;
 
@@ -40,25 +40,28 @@ public class FillMaskProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0 }, // The
                 { 0, 0, 0, 0, 0, 0, 0 }, // capital
                 { 0, 0, 0, 0, 0, 0, 0 }, // of
-                { 0.01, 0.01, 0.3, 0.1, 0.01, 0.2, 1.2 }, // MASK
+                { 0.01, 0.01, 0.3, 0.01, 0.2, 1.2, 0.1 }, // MASK
                 { 0, 0, 0, 0, 0, 0, 0 }, // is
                 { 0, 0, 0, 0, 0, 0, 0 } // paris
             } };
 
         String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris";
 
-        List<String> vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France");
+        List<String> vocab = Arrays.asList("The", "capital", "of", "is", "Paris", "France", BertTokenizer.MASK_TOKEN);
         List<WordPieceTokenFilter.WordPieceToken> tokens = List.of();
 
         int[] tokenMap = new int[] { 0, 1, 2, 3, 4, 5 };
-        int[] tokenIds = new int[] { 0, 1, 2, 3, 4, 5 };
+        int[] tokenIds = new int[] { 0, 1, 2, 6, 4, 5 };
 
-        TokenizationResult tokenization = new TokenizationResult(vocab);
-        tokenization.addTokenization(input, false, tokens, tokenIds, tokenMap);
+        TokenizationResult tokenization = new BertTokenizationResult(
+            vocab,
+            List.of(new TokenizationResult.Tokens(input, tokens, false, tokenIds, tokenMap)),
+            0
+        );
 
         BertTokenizer tokenizer = mock(BertTokenizer.class);
         when(tokenizer.getMaskToken()).thenReturn(BertTokenizer.MASK_TOKEN);
-        when(tokenizer.getMaskTokenId()).thenReturn(OptionalInt.of(3));
+        when(tokenizer.getMaskTokenId()).thenReturn(OptionalInt.of(6));
 
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
@@ -84,8 +87,11 @@ public class FillMaskProcessorTests extends ESTestCase {
         BertTokenizer tokenizer = mock(BertTokenizer.class);
         when(tokenizer.getMaskToken()).thenReturn("[MASK]");
 
-        TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
-        tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {});
+        TokenizationResult tokenization = new BertTokenizationResult(
+            List.of(),
+            List.of(new TokenizationResult.Tokens("", List.of(), false, new int[0], new int[0])),
+            0
+        );
 
         PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
         expectThrows(

+ 7 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetRequestBuilderTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/MPNetTokenizationResultTests.java

@@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 
-public class MPNetRequestBuilderTests extends ESTestCase {
+public class MPNetTokenizationResultTests extends ESTestCase {
     private MPNetTokenizer tokenizer;
 
     @After
@@ -39,9 +39,9 @@ public class MPNetRequestBuilderTests extends ESTestCase {
     public void testBuildRequest() throws IOException {
         tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
 
-        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        var requestBuilder = tokenizer.requestBuilder();
         NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(3));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
@@ -49,7 +49,6 @@ public class MPNetRequestBuilderTests extends ESTestCase {
         assertEquals(Arrays.asList(1, 1, 1, 1, 1), firstListItemFromMap("arg_1", jsonDocAsMap));
     }
 
-    @SuppressWarnings("unchecked")
     private List<Integer> firstListItemFromMap(String name, Map<String, Object> jsonDocAsMap) {
         return nthListItemFromMap(name, 0, jsonDocAsMap);
     }
@@ -62,7 +61,7 @@ public class MPNetRequestBuilderTests extends ESTestCase {
     public void testInputTooLarge() throws IOException {
         tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 5, null)).build();
         {
-            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            var requestBuilder = tokenizer.requestBuilder();
             ElasticsearchStatusException e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> requestBuilder.buildRequest(
@@ -78,7 +77,7 @@ public class MPNetRequestBuilderTests extends ESTestCase {
             );
         }
         {
-            MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+            var requestBuilder = tokenizer.requestBuilder();
             // input will become 3 tokens + the Class and Separator token = 5 which is
             // our max sequence length
             requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
@@ -89,13 +88,13 @@ public class MPNetRequestBuilderTests extends ESTestCase {
     public void testBatchWithPadding() throws IOException {
         tokenizer = MPNetTokenizer.mpBuilder(TEST_CASED_VOCAB, new MPNetTokenization(null, null, 512, null)).build();
 
-        MPNetRequestBuilder requestBuilder = new MPNetRequestBuilder(tokenizer);
+        var requestBuilder = tokenizer.requestBuilder();
         NlpTask.Request request = requestBuilder.buildRequest(
             List.of("Elasticsearch", "my little red car", "Godzilla day"),
             "request1",
             Tokenization.Truncate.NONE
         );
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(3));
         assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));

+ 4 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -182,7 +182,7 @@ public class NerProcessorTests extends ESTestCase {
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -218,7 +218,7 @@ public class NerProcessorTests extends ESTestCase {
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -241,7 +241,7 @@ public class NerProcessorTests extends ESTestCase {
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_PER, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_PER, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.B_ORG, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.B_ORG, 1.0));
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);
         assertThat(entityGroups, hasSize(3));
@@ -272,7 +272,7 @@ public class NerProcessorTests extends ESTestCase {
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
         taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.I_ORG, 1.0));
-        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i++), NerProcessor.IobTag.O, 1.0));
+        taggedTokens.add(new NerProcessor.NerResultProcessor.TaggedToken(tokens.get(i), NerProcessor.IobTag.O, 1.0));
         assertEquals(tokens.size(), taggedTokens.size());
 
         List<NerResults.EntityGroup> entityGroups = NerProcessor.NerResultProcessor.groupTaggedTokens(taggedTokens, input);

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -71,7 +71,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
         NlpTask.Request request = processor.getRequestBuilder(config)
             .buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
 
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java

@@ -50,7 +50,7 @@ public class ZeroShotClassificationProcessorTests extends ESTestCase {
             (NlpConfig) new ZeroShotClassificationConfigUpdate.Builder().setLabels(List.of("new", "stuff")).build().apply(config)
         ).buildRequest(List.of("Elasticsearch fun"), "request1", Tokenization.Truncate.NONE);
 
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput(), true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));

+ 61 - 61
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -47,8 +47,8 @@ public class BertTokenizerTests extends ESTestCase {
         BertTokenizer.PAD_TOKEN
     );
 
-    private List<String> tokenStrings(List<WordPieceTokenFilter.WordPieceToken> tokens) {
-        return tokens.stream().map(WordPieceTokenFilter.WordPieceToken::toString).collect(Collectors.toList());
+    private List<String> tokenStrings(List<? extends DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::toString).collect(Collectors.toList());
     }
 
     public void testTokenize() {
@@ -58,10 +58,10 @@ public class BertTokenizerTests extends ESTestCase {
                 new BertTokenization(null, false, null, Tokenization.Truncate.NONE)
             ).build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", "fun"));
-            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", "fun"));
+            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
     }
 
@@ -103,11 +103,11 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch fun with Pancake and Godzilla",
                 Tokenization.Truncate.FIRST
             );
-            assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.tokenIds());
         }
 
         try (
@@ -120,16 +120,16 @@ public class BertTokenizerTests extends ESTestCase {
                 "Elasticsearch fun with Pancake and Godzilla",
                 Tokenization.Truncate.FIRST
             );
-            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.tokenMap());
         }
     }
 
     public void testTokenizeAppendSpecialTokens() {
         try (BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).build()) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.tokenMap());
         }
     }
 
@@ -143,13 +143,13 @@ public class BertTokenizerTests extends ESTestCase {
                 .build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch " + specialToken + " fun",
                 Tokenization.Truncate.NONE
             );
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", specialToken, "fun"));
-            assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", specialToken, "fun"));
+            assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.tokenMap());
         }
     }
 
@@ -161,13 +161,13 @@ public class BertTokenizerTests extends ESTestCase {
             ).setDoLowerCase(false).setWithSpecialTokens(false).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 3, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 3, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 1 }, tokenization.tokenMap());
 
             tokenization = tokenizer.tokenize("elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
 
         try (
@@ -177,9 +177,9 @@ public class BertTokenizerTests extends ESTestCase {
             ).setDoLowerCase(true).setWithSpecialTokens(false).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
     }
 
@@ -189,14 +189,14 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", ",", "fun", "."));
-            assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", ",", "fun", "."));
+            assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.tokenMap());
 
             tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.tokenMap());
         }
     }
 
@@ -224,20 +224,20 @@ public class BertTokenizerTests extends ESTestCase {
             ).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.tokenMap());
 
             tokenization = tokenizer.tokenize("This is sub~[MASK]!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", "~", "[MASK]", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", "~", "[MASK]", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.tokenMap());
 
             tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.tokenMap());
         }
     }
 
@@ -257,23 +257,23 @@ public class BertTokenizerTests extends ESTestCase {
                     tokenizer.tokenize("Godzilla Pancake red car day", Tokenization.Truncate.NONE)
                 )
             );
-            assertThat(tr.getTokenizations(), hasSize(4));
+            assertThat(tr.getTokens(), hasSize(4));
 
-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tr.getTokenization(0);
+            assertArrayEquals(new int[] { 0, 1 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0 }, tokenization.tokenMap());
 
-            tokenization = tr.getTokenizations().get(1);
-            assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(1);
+            assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.tokenMap());
 
-            tokenization = tr.getTokenizations().get(2);
-            assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(2);
+            assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
 
-            tokenization = tr.getTokenizations().get(3);
-            assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(3);
+            assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.tokenMap());
         }
     }
 
@@ -284,13 +284,13 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(true)
                 .build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.NONE
             );
 
-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
@@ -309,7 +309,7 @@ public class BertTokenizerTests extends ESTestCase {
                     BertTokenizer.SEPARATOR_TOKEN
                 )
             );
-            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.tokenIds());
         }
     }
 
@@ -321,13 +321,13 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.FIRST
             );
 
-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
@@ -359,12 +359,12 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {
 
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.SECOND
             );
-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(

+ 9 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/MPNetTokenizerTests.java

@@ -42,8 +42,8 @@ public class MPNetTokenizerTests extends ESTestCase {
         MPNetTokenizer.PAD_TOKEN
     );
 
-    private List<String> tokenStrings(List<WordPieceTokenFilter.WordPieceToken> tokens) {
-        return tokens.stream().map(WordPieceTokenFilter.WordPieceToken::toString).collect(Collectors.toList());
+    private List<String> tokenStrings(List<? extends DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::toString).collect(Collectors.toList());
     }
 
     public void testTokenize() {
@@ -53,10 +53,10 @@ public class MPNetTokenizerTests extends ESTestCase {
                 new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
             ).build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", "fun"));
-            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", "fun"));
+            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
     }
 
@@ -67,13 +67,13 @@ public class MPNetTokenizerTests extends ESTestCase {
                 new MPNetTokenization(null, false, null, Tokenization.Truncate.NONE)
             ).setDoLowerCase(false).setWithSpecialTokens(true).build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.NONE
             );
 
-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
@@ -93,7 +93,7 @@ public class MPNetTokenizerTests extends ESTestCase {
                     MPNetTokenizer.SEPARATOR_TOKEN
                 )
             );
-            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.tokenIds());
         }
     }