
[ML] Multiple items in a single inference request (#75759)

Inference requests can now be batched by adding more rows to the input tensor.
A single batched call is more performant than multiple calls to forward(), each
with a single input, when all the inputs are of a similar length. The expected
input is now a 2D array of tokens plus 2D arrays of supporting arguments; the
output is a 3D array.
David Kyle 4 years ago
parent
commit
7a283104c5
22 changed files with 515 additions and 208 deletions
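
To make the batching described in the commit message concrete, here is a minimal, self-contained sketch (not part of this change) of padding two token-id sequences to the longest length and stacking them into a single 2D input; the pad token id of 0 and the class name are assumptions for the illustration.

import java.util.Arrays;
import java.util.List;

public class BatchingSketch {
    public static void main(String[] args) {
        int padTokenId = 0; // assumed pad token id for this illustration
        List<int[]> sequences = List.of(new int[] {1, 3, 4, 2}, new int[] {1, 8, 9, 10, 11, 2});
        int longest = sequences.stream().mapToInt(s -> s.length).max().orElse(0);

        int[][] batch = new int[sequences.size()][longest];
        for (int row = 0; row < sequences.size(); row++) {
            Arrays.fill(batch[row], padTokenId);                      // pad the whole row first
            System.arraycopy(sequences.get(row), 0, batch[row], 0, sequences.get(row).length);
        }
        System.out.println(Arrays.deepToString(batch));
        // [[1, 3, 4, 2, 0, 0], [1, 8, 9, 10, 11, 2]]
    }
}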
  1. + 2 - 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java
  2. + 54 - 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java
  3. + 10 - 7
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  4. + 6 - 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  5. + 26 - 20
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResult.java
  6. + 20 - 17
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  7. + 14 - 11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java
  8. + 17 - 13
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  9. + 10 - 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  10. + 52 - 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  11. + 5 - 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  12. + 5 - 30
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  13. + 30 - 9
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  14. + 6 - 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  15. + 61 - 34
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  16. + 5 - 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResultTests.java
  17. + 64 - 10
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java
  18. + 44 - 7
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java
  19. + 9 - 10
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  20. + 6 - 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  21. + 13 - 10
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  22. + 56 - 8
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -68,8 +68,8 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return builder.build();
         }
 
-        private String deploymentId;
-        private List<Map<String, Object>> docs;
+        private final String deploymentId;
+        private final List<Map<String, Object>> docs;
 
         public Request(String deploymentId, List<Map<String, Object>> docs) {
             this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);

+ 54 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlParserUtils.java

@@ -50,4 +50,58 @@ public final class MlParserUtils {
         }
         return values;
     }
+
+    /**
+     * Parses a 3 dimensional array of doubles.
+     *
+     * @param fieldName the field name
+     * @param parser the outer parser
+     * @return The 3D array of doubles
+     * @throws IOException If parsing fails
+     */
+    public static double[][][] parse3DArrayOfDoubles(String fieldName, XContentParser parser) throws IOException {
+        if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
+            throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
+        }
+        List<List<List<Double>>> values = new ArrayList<>();
+        while(parser.nextToken() != XContentParser.Token.END_ARRAY) {
+            if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
+                throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
+            }
+
+            List<List<Double>> innerList = new ArrayList<>();
+
+            while(parser.nextToken() != XContentParser.Token.END_ARRAY) {
+                if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
+                    throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
+                }
+
+                if (parser.currentToken() != XContentParser.Token.START_ARRAY) {
+                    throw new IllegalArgumentException("unexpected token [" + parser.currentToken() + "] for [" + fieldName + "]");
+                }
+
+                List<Double> innerInner = new ArrayList<>();
+                while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+                    if (parser.currentToken() != XContentParser.Token.VALUE_NUMBER) {
+                        throw new IllegalStateException("expected non-null numerical value but got [" + parser.currentToken() + "] " +
+                            "for [" + fieldName + "]");
+                    }
+                    innerInner.add(parser.doubleValue());
+                }
+                innerList.add(innerInner);
+            }
+            values.add(innerList);
+        }
+
+        double [][][] val = new double[values.size()][values.get(0).size()][values.get(0).get(0).size()];
+
+        for (int i = 0; i < val.length; i++) {
+            for (int j = 0; j < val[0].length; j++) {
+                double[] doubles = values.get(i).get(j).stream().mapToDouble(d -> d).toArray();
+                System.arraycopy(doubles, 0, val[i][j], 0, doubles.length);
+            }
+        }
+
+        return val;
+    }
 }
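
As a usage illustration only (not part of the commit), the new helper can be driven from a small JSON payload once the parser is positioned on the opening array; the parser-creation calls below assume the 7.x XContent APIs of the time.

import java.io.IOException;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

public class Parse3DArraySketch {
    public static void main(String[] args) throws IOException {
        String json = "{\"inference\": [[[1.0, 2.0], [3.0, 4.0]]]}";
        try (XContentParser parser = XContentType.JSON.xContent().createParser(
                NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
            parser.nextToken(); // START_OBJECT
            parser.nextToken(); // FIELD_NAME "inference"
            parser.nextToken(); // START_ARRAY, where parse3DArrayOfDoubles expects to begin
            double[][][] result = MlParserUtils.parse3DArrayOfDoubles("inference", parser);
            System.out.println(result[0][1][0]); // 3.0
        }
    }
}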

+ 10 - 7
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -18,10 +18,12 @@ import org.elasticsearch.test.rest.ESRestTestCase;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Base64;
 import java.util.List;
 import java.util.Map;
@@ -109,7 +111,8 @@ public class PyTorchModelIT extends ESRestTestCase {
             "{" +
             "\"transient\" : {\n" +
             "        \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
-            "        \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
+            "        \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\",\n" +
+            "        \"logger.org.elasticsearch.xpack.ml.process.logging\" : \"TRACE\"\n" +
             "    }" +
             "}");
         client().performRequest(loggingSettings);
@@ -124,7 +127,8 @@ public class PyTorchModelIT extends ESRestTestCase {
             "{" +
             "\"transient\" : {\n" +
             "        \"logger.org.elasticsearch.xpack.ml.inference.allocation\" :null,\n" +
-            "        \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null\n" +
+            "        \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null,\n" +
+            "        \"logger.org.elasticsearch.xpack.ml.process.logging\" : null\n" +
             "    }" +
             "}");
         client().performRequest(loggingSettings);
@@ -133,7 +137,6 @@ public class PyTorchModelIT extends ESRestTestCase {
         waitForPendingTasks(adminClient());
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
     public void testEvaluate() throws IOException, InterruptedException {
         String modelId = "test_evaluate";
         createModelStoreIndex();
@@ -168,7 +171,6 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
     public void testLiveDeploymentStats() throws IOException {
         String modelA = "model_a";
 
@@ -193,7 +195,6 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
     public void testGetDeploymentStats_WithWildcard() throws IOException {
 
         {
@@ -262,7 +263,6 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
     public void testGetDeploymentStats_WithStartedStoppedDeployments() throws IOException {
         putVocabulary(List.of("once", "twice"));
         String modelFoo = "foo";
@@ -367,7 +367,10 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
     private void putVocabulary(List<String> vocabulary) throws IOException {
-        String quotedWords = vocabulary.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
+        List<String> vocabularyWithPad = new ArrayList<>();
+        vocabularyWithPad.add(BertTokenizer.PAD_TOKEN);
+        vocabularyWithPad.addAll(vocabulary);
+        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
 
         Request request = new Request("PUT", "/" + VOCAB_INDEX + "/_doc/test_vocab");
         request.setJsonEntity("{  " +

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -49,6 +49,8 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
@@ -232,7 +234,10 @@ public class DeploymentManager {
             @Override
             protected void doRun() {
                 try {
-                    String text = NlpTask.extractInput(processContext.modelInput.get(), doc);
+                    // The request builder expects a list of inputs which are then batched.
+                    // TODO batching was implemented for expected use-cases such as zero-shot
+                    // classification but is not used here.
+                    List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
                     NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                     processor.validateInputs(text);
                     NlpTask.Request request = processor.getRequestBuilder().buildRequest(text, requestId);

+ 26 - 20
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResult.java

@@ -7,21 +7,20 @@
 
 package org.elasticsearch.xpack.ml.inference.deployment;
 
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.List;
 import java.util.Objects;
 
 /**
@@ -37,21 +36,13 @@ public class PyTorchResult implements ToXContentObject, Writeable {
     private static final ParseField TIME_MS = new ParseField("time_ms");
 
     public static final ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>("pytorch_result",
-        a -> new PyTorchResult((String) a[0], (double[][]) a[1], (Long) a[2], (String) a[3]));
+        a -> new PyTorchResult((String) a[0], (double[][][]) a[1], (Long) a[2], (String) a[3]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
         PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c) -> {
-                List<List<Double>> listOfListOfDoubles = MlParserUtils.parseArrayOfArrays(
-                    INFERENCE.getPreferredName(), XContentParser::doubleValue, p);
-                double[][] primitiveDoubles = new double[listOfListOfDoubles.size()][];
-                for (int i = 0; i < listOfListOfDoubles.size(); i++) {
-                    List<Double> row = listOfListOfDoubles.get(i);
-                    primitiveDoubles[i] = row.stream().mapToDouble(d -> d).toArray();
-                }
-                return primitiveDoubles;
-            },
+            (p, c) ->
+                MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
             INFERENCE,
             ObjectParser.ValueType.VALUE_ARRAY
         );
@@ -64,12 +55,12 @@ public class PyTorchResult implements ToXContentObject, Writeable {
     }
 
     private final String requestId;
-    private final double[][] inference;
+    private final double[][][] inference;
     private final Long timeMs;
     private final String error;
 
     public PyTorchResult(String requestId,
-                         @Nullable double[][] inference,
+                         @Nullable double[][][] inference,
                          @Nullable Long timeMs,
                          @Nullable String error) {
         this.requestId = Objects.requireNonNull(requestId);
@@ -82,7 +73,7 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         requestId = in.readString();
         boolean hasInference = in.readBoolean();
         if (hasInference) {
-            inference = in.readArray(StreamInput::readDoubleArray, double[][]::new);
+            inference = in.readArray(in2 -> in2.readArray(StreamInput::readDoubleArray, double[][]::new), double[][][]::new);
         } else {
             inference = null;
         }
@@ -102,7 +93,7 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         return error;
     }
 
-    public double[][] getInferenceResult() {
+    public double[][][] getInferenceResult() {
         return inference;
     }
 
@@ -115,7 +106,20 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         builder.startObject();
         builder.field(REQUEST_ID.getPreferredName(), requestId);
         if (inference != null) {
-            builder.field(INFERENCE.getPreferredName(), inference);
+            builder.startArray(INFERENCE.getPreferredName());
+            for (int i = 0; i < inference.length; i++) {
+                builder.startArray();
+                for (int j = 0; j < inference[0].length; j++)
+                {
+                    builder.startArray();
+                    for (int k = 0; k < inference[0][0].length; k++) {
+                        builder.value(inference[i][j][k]);
+                    }
+                    builder.endArray();
+                }
+                builder.endArray();
+            }
+            builder.endArray();
         }
         if (timeMs != null) {
             builder.field(TIME_MS.getPreferredName(), timeMs);
@@ -134,7 +138,9 @@ public class PyTorchResult implements ToXContentObject, Writeable {
             out.writeBoolean(false);
         } else {
             out.writeBoolean(true);
-            out.writeArray(StreamOutput::writeDoubleArray, inference);
+            out.writeArray(
+                (out2, arr) -> out2.writeArray(StreamOutput::writeDoubleArray, arr),
+                inference);
         }
         out.writeOptionalLong(timeMs);
         out.writeOptionalString(error);
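
For reference, a short round-trip sketch (illustrative, not from the commit) of the nested writeArray/readArray wire format introduced above, using the in-memory BytesStreamOutput commonly used in Elasticsearch tests.

import java.io.IOException;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

public class NestedArrayRoundTripSketch {
    public static void main(String[] args) throws IOException {
        double[][][] inference = new double[][][] {{{0.1, 0.2}, {0.3, 0.4}}};

        BytesStreamOutput out = new BytesStreamOutput();
        // write: an outer array of double[][], each element written as an array of double[]
        out.writeArray((out2, arr) -> out2.writeArray(StreamOutput::writeDoubleArray, arr), inference);

        StreamInput in = out.bytes().streamInput();
        // read: mirror of the write side
        double[][][] roundTripped =
            in.readArray(in2 -> in2.readArray(StreamInput::readDoubleArray, double[][]::new), double[][][]::new);
        System.out.println(roundTripped[0][1][1]); // 0.4
    }
}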

+ 20 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -14,7 +14,7 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.List;
 
 public class BertRequestBuilder implements NlpTask.RequestBuilder {
 
@@ -31,30 +31,33 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     }
 
     @Override
-    public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
-        TokenizationResult tokenization = tokenizer.tokenize(input);
-        return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
+    public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+        if (tokenizer.getPadToken().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
+                " token in its vocabulary");
+        }
+
+        TokenizationResult tokenization = tokenizer.tokenize(inputs);
+        return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
     }
 
-    static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
+    static BytesReference jsonRequest(TokenizationResult tokenization,
+                                      int padToken,
+                                      String requestId) throws IOException {
         XContentBuilder builder = XContentFactory.jsonBuilder();
         builder.startObject();
         builder.field(REQUEST_ID, requestId);
-        builder.array(TOKENS, tokens);
-
-        int[] inputMask = new int[tokens.length];
-        Arrays.fill(inputMask, 1);
-        int[] segmentMask = new int[tokens.length];
-        Arrays.fill(segmentMask, 0);
-        int[] positionalIds = new int[tokens.length];
-        Arrays.setAll(positionalIds, i -> i);
-
-        builder.array(ARG1, inputMask);
-        builder.array(ARG2, segmentMask);
-        builder.array(ARG3, positionalIds);
+
+        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
+        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
+        int batchSize = tokenization.getTokenizations().size();
+        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG2, batchSize, tokenization.getLongestSequenceLength(), i -> 0, builder);
+        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG3, batchSize, tokenization.getLongestSequenceLength(), i -> i, builder);
         builder.endObject();
 
         // BytesReference.bytes closes the builder
         return BytesReference.bytes(builder);
     }
+
+
 }

+ 14 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilder.java

@@ -14,7 +14,7 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.List;
 
 public class DistilBertRequestBuilder implements NlpTask.RequestBuilder {
 
@@ -29,21 +29,24 @@ public class DistilBertRequestBuilder implements NlpTask.RequestBuilder {
     }
 
     @Override
-    public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
-        TokenizationResult result = tokenizer.tokenize(input);
-        return new NlpTask.Request(result, jsonRequest(result.getTokenIds(), requestId));
+    public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+        if (tokenizer.getPadToken().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
+                " token in its vocabulary");
+        }
+
+        TokenizationResult result = tokenizer.tokenize(inputs);
+        return new NlpTask.Request(result, jsonRequest(result, tokenizer.getPadToken().getAsInt(), requestId));
     }
 
-    static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
+    static BytesReference jsonRequest(TokenizationResult tokenization,
+                                      int padToken,
+                                      String requestId) throws IOException {
         XContentBuilder builder = XContentFactory.jsonBuilder();
         builder.startObject();
         builder.field(REQUEST_ID, requestId);
-        builder.array(TOKENS, tokens);
-
-        int[] inputMask = new int[tokens.length];
-        Arrays.fill(inputMask, 1);
-
-        builder.array(ARG1, inputMask);
+        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (tokens, i) -> tokens.getTokenIds()[i], builder);
+        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (tokens, i) -> 1, builder);
         builder.endObject();
 
         // BytesReference.bytes closes the builder

+ 17 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -30,19 +30,21 @@ public class FillMaskProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public void validateInputs(String inputs) {
-        if (inputs.isBlank()) {
+    public void validateInputs(List<String> inputs) {
+        if (inputs.isEmpty()) {
             throw new IllegalArgumentException("input request is empty");
         }
 
-        int maskIndex = inputs.indexOf(BertTokenizer.MASK_TOKEN);
-        if (maskIndex < 0) {
-            throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found");
-        }
+        for (String input : inputs) {
+            int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN);
+            if (maskIndex < 0) {
+                throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found");
+            }
 
-        maskIndex = inputs.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length());
-        if (maskIndex > 0) {
-            throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input");
+            maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length());
+            if (maskIndex > 0) {
+                throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input");
+            }
         }
     }
 
@@ -58,18 +60,20 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
 
-        if (tokenization.getTokens().isEmpty()) {
+        if (tokenization.getTokenizations().isEmpty() ||
+            tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
             return new FillMaskResults(Collections.emptyList());
         }
 
-        int maskTokenIndex = tokenization.getTokens().indexOf(BertTokenizer.MASK_TOKEN);
-        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[maskTokenIndex]);
+        int maskTokenIndex = tokenization.getTokenizations().get(0).getTokens().indexOf(BertTokenizer.MASK_TOKEN);
+        // TODO - process all results in the batch
+        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
 
         NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(NUM_RESULTS, normalizedScores);
         List<FillMaskResults.Prediction> results = new ArrayList<>(NUM_RESULTS);
         for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) {
             String predictedToken = tokenization.getFromVocab(scoreAndIndex.index);
-            String sequence = tokenization.getInput().replace(BertTokenizer.MASK_TOKEN, predictedToken);
+            String sequence = tokenization.getTokenizations().get(0).getInput().replace(BertTokenizer.MASK_TOKEN, predictedToken);
             results.add(new FillMaskResults.Prediction(predictedToken, scoreAndIndex.score, sequence));
         }
         return new FillMaskResults(results);

+ 10 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -119,7 +119,7 @@ public class NerProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public void validateInputs(String inputs) {
+    public void validateInputs(List<String> inputs) {
         // No validation
     }
 
@@ -142,17 +142,20 @@ public class NerProcessor implements NlpTask.Processor {
 
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
-            if (tokenization.getTokens().isEmpty()) {
+            if (tokenization.getTokenizations().isEmpty() ||
+                tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
                 return new NerResults(Collections.emptyList());
             }
+            // TODO - process all results in the batch
+
             // TODO It might be best to do the soft max after averaging scores for
             // sub-tokens. If we had a word that is "elastic" which is tokenized to
             // "el" and "astic" then perhaps we get a prediction for org of 10 for "el"
             // and -5 for "astic". Averaging after softmax would produce a prediction
             // of maybe (1 + 0) / 2 = 0.5 while before softmax it'd be exp(10 - 5) / normalization
             // which could easily be close to 1.
-            double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult());
-            List<TaggedToken> taggedTokens = tagTokens(tokenization, normalizedScores);
+            double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
+            List<TaggedToken> taggedTokens = tagTokens(tokenization.getTokenizations().get(0), normalizedScores, iobMap);
             List<NerResults.EntityGroup> entities = groupTaggedTokens(taggedTokens);
             return new NerResults(entities);
         }
@@ -163,7 +166,9 @@ public class NerProcessor implements NlpTask.Processor {
          * in the original input replacing them with a single token that
          * gets labelled based on the average score of all its sub-tokens.
          */
-        private List<TaggedToken> tagTokens(TokenizationResult tokenization, double[][] scores) {
+        static List<TaggedToken> tagTokens(TokenizationResult.Tokenization tokenization,
+                                           double[][] scores,
+                                           IobTag[] iobMap) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
             while (startTokenIndex < tokenization.getTokens().size()) {

+ 52 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -9,16 +9,18 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
-import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -34,7 +36,7 @@ public class NlpTask {
 
     /**
      * Create and validate the NLP Processor
-     * @return
+     * @return the processor based on task type
      * @throws ValidationException if the validation fails
      */
     public Processor createProcessor() throws ValidationException {
@@ -42,7 +44,53 @@ public class NlpTask {
     }
 
     public interface RequestBuilder {
-        Request buildRequest(String inputs, String requestId) throws IOException;
+        @FunctionalInterface
+        interface IntToIntFunction {
+            int applyAsInt(int value);
+        }
+
+        @FunctionalInterface
+        interface TokenLookupFunction {
+            int apply(TokenizationResult.Tokenization tokenization, int index);
+        }
+
+        Request buildRequest(List<String> inputs, String requestId) throws IOException;
+
+        static void writePaddedTokens(String fieldName,
+                                      TokenizationResult tokenization,
+                                      int padToken,
+                                      TokenLookupFunction generator,
+                                      XContentBuilder builder) throws IOException {
+            builder.startArray(fieldName);
+            for (var inputTokens : tokenization.getTokenizations()) {
+                builder.startArray();
+                int i = 0;
+                for (; i < inputTokens.getTokenIds().length; i++) {
+                    builder.value(generator.apply(inputTokens, i));
+                }
+
+                for (; i < tokenization.getLongestSequenceLength(); i++) {
+                    builder.value(padToken);
+                }
+                builder.endArray();
+            }
+            builder.endArray();
+        }
+
+        static void writeNonPaddedArguments(String fieldName,
+                                            int numTokenizations, int longestSequenceLength,
+                                            IntToIntFunction generator,
+                                            XContentBuilder builder) throws IOException {
+            builder.startArray(fieldName);
+            for (int i = 0; i < numTokenizations; i++) {
+                builder.startArray();
+                for (int j = 0; j < longestSequenceLength; j++) {
+                    builder.value(generator.applyAsInt(j));
+                }
+                builder.endArray();
+            }
+            builder.endArray();
+        }
     }
 
     public interface ResultProcessor {
@@ -60,7 +108,7 @@ public class NlpTask {
          *
          * @param inputs Text to validate
          */
-        void validateInputs(String inputs);
+        void validateInputs(List<String> inputs);
 
         RequestBuilder getRequestBuilder();
         ResultProcessor getResultProcessor();
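
A minimal sketch (not part of the commit) of how the two static helpers above lay out a batched request body; the vocabulary, token ids and pad token id of 0 are made up for the illustration.

import java.io.IOException;
import java.util.List;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

public class PaddedRequestSketch {
    public static void main(String[] args) throws IOException {
        int padTokenId = 0; // assumed pad token id
        TokenizationResult tokenization = new TokenizationResult(List.of("[PAD]", "short", "a", "longer", "input"));
        tokenization.addTokenization("short", List.of("short"), new int[] {1}, new int[] {0});
        tokenization.addTokenization("a longer input", List.of("a", "longer", "input"),
            new int[] {2, 3, 4}, new int[] {0, 1, 2});

        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        // token ids, padded per row to the longest sequence in the batch
        NlpTask.RequestBuilder.writePaddedTokens("tokens", tokenization, padTokenId,
            (tokens, i) -> tokens.getTokenIds()[i], builder);
        // positional ids 0..longest-1, identical for every row, no padding needed
        NlpTask.RequestBuilder.writeNonPaddedArguments("arg_3", tokenization.getTokenizations().size(),
            tokenization.getLongestSequenceLength(), i -> i, builder);
        builder.endObject();
        System.out.println(BytesReference.bytes(builder).utf8ToString());
        // {"tokens":[[1,0,0],[2,3,4]],"arg_3":[[0,1,2],[0,1,2]]}
    }
}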

+ 5 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -14,6 +14,8 @@ import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
+import java.util.List;
+
 /**
  * A NLP processor that directly returns the PyTorch result
  * without any post-processing
@@ -27,7 +29,7 @@ public class PassThroughProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public void validateInputs(String inputs) {
+    public void validateInputs(List<String> inputs) {
         // nothing to validate
     }
 
@@ -42,6 +44,7 @@ public class PassThroughProcessor implements NlpTask.Processor {
     }
 
     private static InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
-        return new PyTorchPassThroughResults(pyTorchResult.getInferenceResult());
+        // TODO - process all results in the batch
+        return new PyTorchPassThroughResults(pyTorchResult.getInferenceResult()[0]);
     }
 }

+ 5 - 30
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -9,9 +9,6 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -21,8 +18,6 @@ import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
-import java.io.IOException;
-import java.util.Arrays;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Locale;
@@ -31,12 +26,12 @@ import java.util.stream.IntStream;
 
 public class TextClassificationProcessor implements NlpTask.Processor {
 
-    private final NlpTokenizer tokenizer;
+    private final NlpTask.RequestBuilder requestBuilder;
     private final String[] classLabels;
     private final int numTopClasses;
 
     TextClassificationProcessor(NlpTokenizer tokenizer, TextClassificationConfig config) {
-        this.tokenizer = tokenizer;
+        this.requestBuilder = tokenizer.requestBuilder();
         List<String> classLabels = config.getClassificationLabels();
         if (classLabels == null || classLabels.isEmpty()) {
             this.classLabels = new String[] {"negative", "positive"};
@@ -73,18 +68,13 @@ public class TextClassificationProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public void validateInputs(String inputs) {
+    public void validateInputs(List<String> inputs) {
         // nothing to validate
     }
 
     @Override
     public NlpTask.RequestBuilder getRequestBuilder() {
-        return this::buildRequest;
-    }
-
-    NlpTask.Request buildRequest(String input, String requestId) throws IOException {
-        TokenizationResult tokenization = tokenizer.tokenize(input);
-        return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
+        return requestBuilder;
     }
 
     @Override
@@ -105,7 +95,7 @@ public class TextClassificationProcessor implements NlpTask.Processor {
             );
         }
 
-        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
+        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][0]);
         return new TextClassificationResults(
             IntStream.range(0, normalizedScores.length)
                 .mapToObj(i -> new TopClassEntry(classLabels[i], normalizedScores[i]))
@@ -115,19 +105,4 @@ public class TextClassificationProcessor implements NlpTask.Processor {
                 .collect(Collectors.toList())
         );
     }
-
-    static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
-        XContentBuilder builder = XContentFactory.jsonBuilder();
-        builder.startObject();
-        builder.field(BertRequestBuilder.REQUEST_ID, requestId);
-        builder.array(BertRequestBuilder.TOKENS, tokens);
-
-        int[] inputMask = new int[tokens.length];
-        Arrays.fill(inputMask, 1);
-        builder.array(BertRequestBuilder.ARG1, inputMask);
-        builder.endObject();
-
-        // BytesReference.bytes closes the builder
-        return BytesReference.bytes(builder);
-    }
 }

+ 30 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -13,10 +13,9 @@ import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
@@ -42,7 +41,7 @@ public class BertTokenizer implements NlpTokenizer {
 
     public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;
 
-    private final Set<String> NEVER_SPLIT = new HashSet<>(Arrays.asList(MASK_TOKEN));
+    private final Set<String> NEVER_SPLIT =  Set.of(MASK_TOKEN);
 
     private final WordPieceTokenizer wordPieceTokenizer;
     private final List<String> originalVocab;
@@ -78,16 +77,28 @@ public class BertTokenizer implements NlpTokenizer {
     }
 
     /**
-     * Tokenize the input according to the basic tokenization options
-     * then perform Word Piece tokenization with the given vocabulary.
+     * Tokenize the list of inputs according to the basic tokenization
+     * options then perform Word Piece tokenization with the given vocabulary.
      *
      * The result is the Word Piece tokens, a map of the Word Piece
-     * token position to the position of the token in the source
+     * token position to the position of the token in the source for
+     * each input string grouped into a {@link Tokenization}.
+     *
      * @param text Text to tokenize
-     * @return Tokenized text, token Ids and map
+     * @return A {@link Tokenization}
      */
     @Override
-    public TokenizationResult tokenize(String text) {
+    public TokenizationResult tokenize(List<String> text) {
+        TokenizationResult tokenization = new TokenizationResult(originalVocab);
+
+        for (String input: text) {
+            addTokenization(tokenization, input);
+        }
+        return tokenization;
+    }
+
+
+    private void addTokenization(TokenizationResult tokenization, String text) {
         BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
 
         List<String> delineatedTokens = basicTokenizer.tokenize(text);
@@ -145,7 +156,17 @@ public class BertTokenizer implements NlpTokenizer {
             );
         }
 
-        return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
+        tokenization.addTokenization(text, tokens, tokenIds, tokenMap);
+    }
+
+    @Override
+    public OptionalInt getPadToken() {
+        Integer pad = vocab.get(PAD_TOKEN);
+        if (pad != null) {
+            return OptionalInt.of(pad);
+        } else {
+            return OptionalInt.empty();
+        }
     }
 
     @Override
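
A brief illustration (not in the commit) of the new getPadToken() accessor: it is empty when the vocabulary lacks the pad token, which is exactly the condition the request builders now reject up front; the vocabularies below are made up.

import java.util.Arrays;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;

public class PadTokenSketch {
    public static void main(String[] args) {
        BertTokenizer withoutPad = BertTokenizer.builder(
            Arrays.asList("Elastic", "##search", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
            new BertTokenization(null, null, 512)).build();
        System.out.println(withoutPad.getPadToken().isEmpty()); // true - request builders throw for this

        BertTokenizer withPad = BertTokenizer.builder(
            Arrays.asList(BertTokenizer.PAD_TOKEN, "Elastic", "##search",
                BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
            new BertTokenization(null, null, 512)).build();
        System.out.println(withPad.getPadToken().getAsInt()); // 0, the index of [PAD] in the vocabulary
    }
}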

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -16,15 +16,20 @@ import org.elasticsearch.xpack.ml.inference.nlp.DistilBertRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
 
+import java.util.List;
+import java.util.OptionalInt;
+
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.TOKENIZATION;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.VOCABULARY;
 
 public interface NlpTokenizer {
 
-    TokenizationResult tokenize(String text);
+    TokenizationResult tokenize(List<String> text);
 
     NlpTask.RequestBuilder requestBuilder();
 
+    OptionalInt getPadToken();
+
     static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) {
         ExceptionsHelper.requireNonNull(params, TOKENIZATION);
         ExceptionsHelper.requireNonNull(vocabulary, VOCABULARY);

+ 61 - 34
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -7,57 +7,84 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
+import java.util.ArrayList;
 import java.util.List;
 
 public class TokenizationResult {
 
-    String input;
-    final List<String> vocab;
-    private final List<String> tokens;
-    private final int [] tokenIds;
-    private final int [] tokenMap;
+    private final List<String> vocab;
+    private final List<Tokenization> tokenizations = new ArrayList<>();
+    private int maxLength;
 
-    public TokenizationResult(String input, List<String> vocab, List<String> tokens, int[] tokenIds, int[] tokenMap) {
-        assert tokens.size() == tokenIds.length;
-        assert tokenIds.length == tokenMap.length;
-        this.input = input;
+    public TokenizationResult(List<String> vocab) {
         this.vocab = vocab;
-        this.tokens = tokens;
-        this.tokenIds = tokenIds;
-        this.tokenMap = tokenMap;
+        this.maxLength = -1;
     }
 
     public String getFromVocab(int tokenId) {
         return vocab.get(tokenId);
     }
 
-    /**
-     * The token strings from the tokenization process
-     * @return A list of tokens
-     */
-    public List<String> getTokens() {
-        return tokens;
+    public List<Tokenization> getTokenizations() {
+        return tokenizations;
     }
 
-    /**
-     * The integer values of the tokens in {@link #getTokens()}
-     * @return A list of token Ids
-     */
-    public int[] getTokenIds() {
-        return tokenIds;
+    public void addTokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
+        maxLength = Math.max(maxLength, tokenIds.length);
+        tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap));
     }
 
-    /**
-     * Maps the token position to the position in the source text.
-     * Source words may be divided into more than one token so more
-     * than one token can map back to the source token
-     * @return Map of source token to
-     */
-    public int[] getTokenMap() {
-        return tokenMap;
+    public int getLongestSequenceLength() {
+        return maxLength;
     }
 
-    public String getInput() {
-        return input;
+    public static class Tokenization {
+
+        String input;
+        private final List<String> tokens;
+        private final int[] tokenIds;
+        private final int[] tokenMap;
+
+        public Tokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
+            assert tokens.size() == tokenIds.length;
+            assert tokenIds.length == tokenMap.length;
+            this.input = input;
+            this.tokens = tokens;
+            this.tokenIds = tokenIds;
+            this.tokenMap = tokenMap;
+        }
+
+        /**
+         * The token strings from the tokenization process
+         *
+         * @return A list of tokens
+         */
+        public List<String> getTokens() {
+            return tokens;
+        }
+
+        /**
+         * The integer values of the tokens in {@link #getTokens()}
+         *
+         * @return A list of token Ids
+         */
+        public int[] getTokenIds() {
+            return tokenIds;
+        }
+
+        /**
+         * Maps the token position to the position in the source text.
+         * Source words may be divided into more than one token so more
+         * than one token can map back to the source token
+         *
+         * @return Map of source token to
+         */
+        public int[] getTokenMap() {
+            return tokenMap;
+        }
+
+        public String getInput() {
+            return input;
+        }
     }
 }
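
A short sketch (illustrative, not part of the change) of how the reworked result accumulates one Tokenization per input and tracks the longest sequence, which the request builders use to pad shorter rows; the vocabulary and inputs are made up.

import java.util.List;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

public class TokenizationResultSketch {
    public static void main(String[] args) {
        TokenizationResult result = new TokenizationResult(List.of("[PAD]", "foo", "bar"));
        result.addTokenization("foo", List.of("foo"), new int[] {1}, new int[] {0});
        result.addTokenization("foo bar", List.of("foo", "bar"), new int[] {1, 2}, new int[] {0, 1});

        System.out.println(result.getTokenizations().size());  // 2, one Tokenization per input
        System.out.println(result.getLongestSequenceLength()); // 2, used to pad the shorter rows
        System.out.println(result.getFromVocab(2));            // "bar"
    }
}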

+ 5 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResultTests.java

@@ -33,10 +33,13 @@ public class PyTorchResultTests extends AbstractSerializingTestCase<PyTorchResul
         } else {
             int rows = randomIntBetween(1, 10);
             int columns = randomIntBetween(1, 10);
-            double [][] arr = new double[rows][columns];
+            int depth = randomIntBetween(1, 10);
+            double [][][] arr = new double[rows][columns][depth];
             for (int i=0; i<rows; i++) {
                 for (int j=0; j<columns; j++) {
-                    arr[i][j] = randomDouble();
+                    for (int k=0; k<depth; k++) {
+                        arr[i][j][k] = randomDouble();
+                    }
                 }
             }
             return new PyTorchResult(id, arr, randomLong(), null);

+ 64 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java

@@ -16,6 +16,8 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.Matchers.containsString;
@@ -25,32 +27,42 @@ public class BertRequestBuilderTests extends ESTestCase {
 
     public void testBuildRequest() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
             new BertTokenization(null, null, 512)
         ).build();
 
         BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
-        NlpTask.Request request = requestBuilder.buildRequest("Elasticsearch fun", "request1");
-
+        NlpTask.Request request = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1");
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
-        assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens"));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1"));
-        assertEquals(Arrays.asList(0, 0, 0, 0, 0), jsonDocAsMap.get("arg_2"));
-        assertEquals(Arrays.asList(0, 1, 2, 3, 4), jsonDocAsMap.get("arg_3"));
+        assertEquals(Arrays.asList(3, 0, 1, 2, 4), firstListItemFromMap("tokens", jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), firstListItemFromMap("arg_1", jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 0, 0, 0, 0), firstListItemFromMap("arg_2", jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 1, 2, 3, 4), firstListItemFromMap("arg_3", jsonDocAsMap));
+    }
+
+    @SuppressWarnings("unchecked")
+    private List<Integer> firstListItemFromMap(String name, Map<String, Object> jsonDocAsMap) {
+        return nthListItemFromMap(name, 0, jsonDocAsMap);
+    }
+
+    @SuppressWarnings("unchecked")
+    public static List<Integer> nthListItemFromMap(String name, int n, Map<String, Object> jsonDocAsMap) {
+        return ((List<List<Integer>>)jsonDocAsMap.get(name)).get(n);
     }
 
     public void testInputTooLarge() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
             new BertTokenization(null, null, 5)
         ).build();
         {
             BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
             ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-                () -> requestBuilder.buildRequest("Elasticsearch fun Elasticsearch fun Elasticsearch fun", "request1"));
+                () -> requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun Elasticsearch fun Elasticsearch fun"),
+                    "request1"));
 
             assertThat(e.getMessage(),
                 containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
@@ -59,7 +71,49 @@ public class BertRequestBuilderTests extends ESTestCase {
             BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
             // input will become 3 tokens + the Class and Separator token = 5 which is
             // our max sequence length
-            requestBuilder.buildRequest("Elasticsearch fun", "request1");
+            requestBuilder.buildRequest(Collections.singletonList("Elasticsearch fun"), "request1");
         }
     }
+
+    @SuppressWarnings("unchecked")
+    public void testBatchWithPadding() throws IOException {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList(BertTokenizer.PAD_TOKEN, BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN,
+                "Elastic", "##search", "fun",
+                "Pancake", "day",
+                "my", "little", "red", "car",
+                "God", "##zilla"
+                ),
+            new BertTokenization(null, null, 512)
+        ).build();
+
+        BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(
+            List.of("Elasticsearch",
+                "my little red car",
+                "Godzilla day"), "request1");
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(5));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_2"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_3"), hasSize(3));
+
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(1, 3, 4, 2, 0, 0), nthListItemFromMap("tokens", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 0, 0), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 0, 0, 0, 0, 0), nthListItemFromMap("arg_2", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 1, 2, 3, 4, 5), nthListItemFromMap("arg_3", 0, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(1, 8, 9, 10, 11, 2), nthListItemFromMap("tokens", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 0, 0, 0, 0, 0), nthListItemFromMap("arg_2", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 1, 2, 3, 4, 5), nthListItemFromMap("arg_3", 1, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(1, 12, 13, 7, 2, 0), nthListItemFromMap("tokens", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 0), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 0, 0, 0, 0, 0), nthListItemFromMap("arg_2", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(0, 1, 2, 3, 4, 5), nthListItemFromMap("arg_3", 2, jsonDocAsMap));
+    }
 }

+ 44 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/DistilBertRequestBuilderTests.java

@@ -12,13 +12,16 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilderTests.nthListItemFromMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 
@@ -26,30 +29,30 @@ public class DistilBertRequestBuilderTests extends ESTestCase {
 
     public void testBuildRequest() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
             new DistilBertTokenization(null, null, 512)
         ).build();
 
         DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
-        BytesReference bytesReference = requestBuilder.buildRequest("Elasticsearch fun", "request1").processInput;
+        BytesReference bytesReference = requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1").processInput;
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(3));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
-        assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens"));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1"));
+        assertEquals(Arrays.asList(3, 0, 1, 2, 4), nthListItemFromMap("tokens", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
     }
 
     public void testInputTooLarge() throws IOException {
         BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN),
             new DistilBertTokenization(null, null, 5)
         ).build();
         {
             DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
             ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-                () -> requestBuilder.buildRequest("Elasticsearch fun Elasticsearch fun Elasticsearch fun", "request1"));
+                () -> requestBuilder.buildRequest(List.of("Elasticsearch fun Elasticsearch fun Elasticsearch fun"), "request1"));
 
             assertThat(e.getMessage(),
                 containsString("Input too large. The tokenized input length [11] exceeds the maximum sequence length [5]"));
@@ -58,7 +61,41 @@ public class DistilBertRequestBuilderTests extends ESTestCase {
             DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
             // input will become 3 tokens + the Class and Separator token = 5 which is
             // our max sequence length
-            requestBuilder.buildRequest("Elasticsearch fun", "request1");
+            requestBuilder.buildRequest(List.of("Elasticsearch fun"), "request1");
         }
     }
+
+    @SuppressWarnings("unchecked")
+    public void testBatchWithPadding() throws IOException {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList(BertTokenizer.PAD_TOKEN, BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN,
+                "Elastic", "##search", "fun",
+                "Pancake", "day",
+                "my", "little", "red", "car",
+                "God", "##zilla"
+            ),
+            new BertTokenization(null, null, 512)
+        ).build();
+
+        DistilBertRequestBuilder requestBuilder = new DistilBertRequestBuilder(tokenizer);
+        NlpTask.Request request = requestBuilder.buildRequest(
+            List.of("Elasticsearch",
+                "my little red car",
+                "Godzilla day"), "request1");
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
+
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("tokens"), hasSize(3));
+        assertThat((List<List<Integer>>) jsonDocAsMap.get("arg_1"), hasSize(3));
+
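+        // Each input doc is one row, right-padded with the PAD id (0 here) to the longest
+        // sequence in the batch (6 tokens); arg_1 is 0 at the padded positions.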
+        assertEquals(Arrays.asList(1, 3, 4, 2, 0, 0), nthListItemFromMap("tokens", 0, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 0, 0), nthListItemFromMap("arg_1", 0, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(1, 8, 9, 10, 11, 2), nthListItemFromMap("tokens", 1, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 1), nthListItemFromMap("arg_1", 1, jsonDocAsMap));
+
+        assertEquals(Arrays.asList(1, 12, 13, 7, 2, 0), nthListItemFromMap("tokens", 2, jsonDocAsMap));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1, 0), nthListItemFromMap("arg_1", 2, jsonDocAsMap));
+    }
 }

+ 9 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -29,14 +29,14 @@ public class FillMaskProcessorTests extends ESTestCase {
     public void testProcessResults() {
         // only the scores at the MASK token's index
         // are used; the rest is filler
-        double[][] scores = {
+        double[][][] scores = {{
             { 0, 0, 0, 0, 0, 0, 0}, // The
             { 0, 0, 0, 0, 0, 0, 0}, // capital
             { 0, 0, 0, 0, 0, 0, 0}, // of
             { 0.01, 0.01, 0.3, 0.1, 0.01, 0.2, 1.2}, // MASK
             { 0, 0, 0, 0, 0, 0, 0}, // is
             { 0, 0, 0, 0, 0, 0, 0} // paris
-        };
+        }};
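+        // the extra outer array is the batch dimension of the new 3D output: one [tokens][scores] matrix per input doc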
 
         String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris";
 
@@ -45,7 +45,8 @@ public class FillMaskProcessorTests extends ESTestCase {
         int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5};
 
-        TokenizationResult tokenization = new TokenizationResult(input, vocab, tokens, tokenIds, tokenMap);
+        TokenizationResult tokenization = new TokenizationResult(vocab);
+        tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
 
@@ -66,21 +67,19 @@ public class FillMaskProcessorTests extends ESTestCase {
     }
 
     public void testProcessResults_GivenMissingTokens() {
-        TokenizationResult tokenization =
-            new TokenizationResult("", Collections.emptyList(), Collections.emptyList(),
-            new int[] {}, new int[] {});
+        TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
+        tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
-        PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][]{{}}, 0L, null);
-
+        PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
         FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
 
         assertThat(result.getPredictions(), empty());
     }
 
     public void testValidate_GivenMissingMaskToken() {
-        String input = "The capital of France is Paris";
+        List<String> input = List.of("The capital of France is Paris");
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
@@ -92,7 +91,7 @@ public class FillMaskProcessorTests extends ESTestCase {
 
 
     public void testProcessResults_GivenMultipleMaskTokens() {
-        String input = "The capital of [MASK] is [MASK]";
+        List<String> input = List.of("The capital of [MASK] is [MASK]");
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index", "vocab"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);

+ 6 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -95,7 +95,8 @@ public class NerProcessorTests extends ESTestCase {
             Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
             "Many use Elasticsearch in London"
         );
-        double[][] scores = {
+
+        double[][][] scores = {{
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // many
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // use
             { 0.01, 0.01, 0, 0.01, 0, 7, 0, 3, 0}, // el
@@ -103,7 +104,7 @@ public class NerProcessorTests extends ESTestCase {
             { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // ##search
             { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in
             { 0, 0, 0, 0, 0, 0, 0, 6, 0} // london
-        };
+        }};
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
 
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -133,13 +134,13 @@ public class NerProcessorTests extends ESTestCase {
             "Elasticsearch in London"
         );
 
-        double[][] scores = {
+        double[][][] scores = {{
             { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0}, // el
             { 0.01, 0.01, 0, 0, 0, 0, 0, 0, 0}, // ##astic
             { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // ##search
             { 0, 0, 0, 0, 0, 0, 0, 0, 5}, // in
             { 6, 0, 0, 0, 0, 0, 0, 0, 0} // london
-        };
+        }};
         NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
 
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -225,6 +226,6 @@ public class NerProcessorTests extends ESTestCase {
                 new DistilBertTokenization(true, false, null)
             )
         ).setDoLowerCase(true).setWithSpecialTokens(false).build();
-        return tokenizer.tokenize(input);
+        return tokenizer.tokenize(List.of(input));
     }
 }

+ 13 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -13,11 +13,12 @@ import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DistilBertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -35,13 +36,13 @@ public class TextClassificationProcessorTests extends ESTestCase {
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][] {}, 0L, null);
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
             InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
             assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
         }
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][] { { 1.0 } }, 0L, null);
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
             InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
             assertEquals(
@@ -51,23 +52,25 @@ public class TextClassificationProcessorTests extends ESTestCase {
         }
     }
 
+    @SuppressWarnings("unchecked")
     public void testBuildRequest() throws IOException {
-        BertTokenizer tokenizer = BertTokenizer.builder(
-            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN),
-            new BertTokenization(null, null, 512)
-        ).build();
+        NlpTokenizer tokenizer = NlpTokenizer.build(
+            new Vocabulary(
+                Arrays.asList("Elastic", "##search", "fun",
+                    BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN, BertTokenizer.PAD_TOKEN)),
+            new DistilBertTokenization(null, null, 512));
 
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index", "vocab"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
-        NlpTask.Request request = processor.buildRequest("Elasticsearch fun", "request1");
+        NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(3));
         assertEquals("request1", jsonDocAsMap.get("request_id"));
-        assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens"));
-        assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1"));
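+        // tokens and arg_1 are now 2D (one row per input doc), so check row 0 for the single doc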
+        assertEquals(Arrays.asList(3, 0, 1, 2, 4), ((List<List<Integer>>)jsonDocAsMap.get("tokens")).get(0));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), ((List<List<Integer>>)jsonDocAsMap.get("arg_1")).get(0));
     }
 
     public void testValidate() {

+ 56 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -13,8 +13,10 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.hasSize;
 
 public class BertTokenizerTests extends ESTestCase {
 
@@ -24,7 +26,8 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null)
         ).build();
 
-        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
+        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
         assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
@@ -36,7 +39,8 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
         ).build();
 
-        TokenizationResult tokenization = tokenizer.tokenize("elasticsearch fun");
+        TokenizationResult tr = tokenizer.tokenize(List.of("elasticsearch fun"));
+        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
         assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
@@ -52,7 +56,8 @@ public class BertTokenizerTests extends ESTestCase {
          .setWithSpecialTokens(false)
          .build();
 
-        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
+        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch " + specialToken + " fun"));
+        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
         assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
@@ -67,12 +72,14 @@ public class BertTokenizerTests extends ESTestCase {
              .setWithSpecialTokens(false)
              .build();
 
-            TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
+            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
             assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
             assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
 
-            tokenization = tokenizer.tokenize("elasticsearch fun");
+            tr = tokenizer.tokenize(List.of("elasticsearch fun"));
+            tokenization = tr.getTokenizations().get(0);
             assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
         }
 
@@ -82,7 +89,8 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .build();
 
-            TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
+            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
+            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
             assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
         }
     }
@@ -93,14 +101,54 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
         ).setWithSpecialTokens(false).build();
 
-        TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch, fun.");
+        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch, fun."));
+        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
 
-        tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].");
+        tr = tokenizer.tokenize(List.of("Elasticsearch, fun [MASK]."));
+        tokenization = tr.getTokenizations().get(0);
         assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
     }
+
+    public void testBatchInput() {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList("Elastic", "##search", "fun",
+                "Pancake", "day",
+                "my", "little", "red", "car",
+                "God", "##zilla"
+                ),
+            new BertTokenization(null, false, null)
+        ).build();
+
+        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch",
+            "my little red car",
+            "Godzilla day",
+            "Godzilla Pancake red car day"
+            ));
+        assertThat(tr.getTokenizations(), hasSize(4));
+
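+        // each string in the batch is tokenized independently and the results keep the input order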
+        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
+        assertThat(tokenization.getTokens(), contains("Elastic", "##search"));
+        assertArrayEquals(new int[] {0, 1}, tokenization.getTokenIds());
+        assertArrayEquals(new int[] {0, 0}, tokenization.getTokenMap());
+
+        tokenization = tr.getTokenizations().get(1);
+        assertThat(tokenization.getTokens(), contains("my", "little", "red", "car"));
+        assertArrayEquals(new int[] {5, 6, 7, 8}, tokenization.getTokenIds());
+        assertArrayEquals(new int[] {0, 1, 2, 3}, tokenization.getTokenMap());
+
+        tokenization = tr.getTokenizations().get(2);
+        assertThat(tokenization.getTokens(), contains("God", "##zilla", "day"));
+        assertArrayEquals(new int[] {9, 10, 4}, tokenization.getTokenIds());
+        assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
+
+        tokenization = tr.getTokenizations().get(3);
+        assertThat(tokenization.getTokens(), contains("God", "##zilla", "Pancake", "red", "car", "day"));
+        assertArrayEquals(new int[] {9, 10, 3, 7, 8, 4}, tokenization.getTokenIds());
+        assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
+    }
 }