
[ML] Guard against input sequences that are too long for Question Answering models (#91924)

Adds checks that reject requests where the tokenized question, or the question plus the configured span window, is longer than the model's max sequence length.
David Kyle committed 2 years ago · commit 6c643c59fe
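
For context, the arithmetic the new guards enforce can be sketched as below. This is a minimal illustration, not the actual NlpTokenizer API: the method and parameter names are made up, but the formula mirrors the check added in the diff, where the question (the first sequence), the tokenizer's special tokens, and the span window must all fit inside the model's max sequence length.

    // Illustrative sketch only, assuming a BERT-style pair layout [CLS] question [SEP] window [SEP].
    class SpanGuardSketch {
        static void checkQuestionAndSpanFit(int questionTokens, int extraTokens, int span, int maxSequenceLength) {
            // Room left for each window of the second sequence once the question
            // and the special tokens are accounted for.
            int trueMaxSeqLength = maxSequenceLength - extraTokens - questionTokens;
            if (trueMaxSeqLength <= 0) {
                // The question alone does not fit, so no part of the document could ever be windowed in.
                throw new IllegalArgumentException("the first sequence is longer than the max sequence length");
            }
            if (span > trueMaxSeqLength) {
                // The requested overlap is larger than the usable window, so [span] must be reduced.
                throw new IllegalArgumentException("reduce the size of the [span] window");
            }
        }
    }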

+ 5 - 0
docs/changelog/91924.yaml

@@ -0,0 +1,5 @@
+pr: 91924
+summary: Guard against input sequences that are too long for Question Answering models
+area: Machine Learning
+type: bug
+issues: []

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java

@@ -94,7 +94,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             boolean previouslyLicensed
         ) {
             this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
-            this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
+            this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
             this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
             this.previouslyLicensed = previouslyLicensed;
             this.timeout = timeout;
@@ -112,7 +112,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
         public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfigUpdate update, boolean previouslyLicensed) {
             this(
                 modelId,
-                Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
+                Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, DOCS.getPreferredName())),
                 update,
                 TimeValue.MAX_VALUE,
                 previouslyLicensed

+ 29 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
@@ -227,15 +228,16 @@ public abstract class NlpTokenizer implements Releasable {
      * @return tokenization result for the sequence pair
      */
     public List<TokenizationResult.Tokens> tokenize(String seq1, String seq2, Tokenization.Truncate truncate, int span, int sequenceId) {
+        if (isWithSpecialTokens() == false) {
+            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
+        }
+
         var innerResultSeq1 = innerTokenize(seq1);
         List<? extends DelimitedToken.Encoded> tokenIdsSeq1 = innerResultSeq1.tokens;
         List<Integer> tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap;
         var innerResultSeq2 = innerTokenize(seq2);
         List<? extends DelimitedToken.Encoded> tokenIdsSeq2 = innerResultSeq2.tokens;
         List<Integer> tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap;
-        if (isWithSpecialTokens() == false) {
-            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
-        }
         int extraTokens = getNumExtraTokensForSeqPair();
         int numTokens = tokenIdsSeq1.size() + tokenIdsSeq2.size() + extraTokens;
 
@@ -296,6 +298,30 @@ public abstract class NlpTokenizer implements Releasable {
         List<Integer> seq1TokenIds = tokenIdsSeq1.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList());
 
         final int trueMaxSeqLength = maxSequenceLength() - extraTokens - tokenIdsSeq1.size();
+        if (trueMaxSeqLength <= 0) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Unable to do sequence pair tokenization: the first sequence [%d tokens] "
+                        + "is longer than the max sequence length [%d tokens]",
+                    tokenIdsSeq1.size() + extraTokens,
+                    maxSequenceLength()
+                )
+            );
+        }
+
+        if (span > trueMaxSeqLength) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Unable to do sequence pair tokenization: the combined first sequence and span length [%d + %d = %d tokens] "
+                        + "is longer than the max sequence length [%d tokens]. Reduce the size of the [span] window.",
+                    tokenIdsSeq1.size(),
+                    span,
+                    tokenIdsSeq1.size() + span,
+                    maxSequenceLength()
+                )
+            );
+        }
+
         while (splitEndPos < tokenIdsSeq2.size()) {
             splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size());
             // Make sure we do not end on a word
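
For intuition on the second guard, here is a hedged sketch of the overlapping-window split that follows, assuming (as the loop above suggests) that each new window starts `span` tokens before the previous one ends. Names and the strictness of the precondition are illustrative, not the actual implementation.

    import java.util.ArrayList;
    import java.util.List;

    // Hedged sketch: split a document of docTokenCount tokens into windows of at most
    // trueMaxSeqLength tokens, with consecutive windows overlapping by `span` tokens.
    static List<int[]> windowBoundaries(int docTokenCount, int trueMaxSeqLength, int span) {
        // The real guard rejects span > trueMaxSeqLength; this sketch requires strict
        // inequality so the simplified loop below is guaranteed to make progress.
        if (trueMaxSeqLength <= 0 || span >= trueMaxSeqLength) {
            throw new IllegalArgumentException("inputs the new guards are there to reject");
        }
        List<int[]> windows = new ArrayList<>();
        int splitStartPos = 0;
        int splitEndPos = 0;
        while (splitEndPos < docTokenCount) {
            splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, docTokenCount);
            windows.add(new int[] { splitStartPos, splitEndPos });
            splitStartPos = splitEndPos - span; // if span exceeded trueMaxSeqLength this would
                                                // move backwards and the split could never finish
        }
        return windows;
    }

For example, with maxSequenceLength = 512, 3 special tokens and a 100-token question, each window can hold 512 - 3 - 100 = 409 document tokens, and any span up to that size is accepted.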

+ 45 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -19,6 +19,7 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
@@ -469,6 +470,50 @@ public class BertTokenizerTests extends ESTestCase {
         }
     }
 
+    public void testMultiSeqTokenizationWithSpanFirstInputTooLong() {
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault())
+                .setDoLowerCase(false)
+                .setWithSpecialTokens(true)
+                .setMaxSequenceLength(3)
+                .build()
+        ) {
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.NONE, 2, 0)
+            );
+            assertThat(
+                iae.getMessage(),
+                containsString(
+                    "Unable to do sequence pair tokenization: the first sequence [7 tokens] "
+                        + "is longer than the max sequence length [3 tokens]"
+                )
+            );
+        }
+    }
+
+    public void testMultiSeqTokenizationWithSpanPlusFirstInputTooLong() {
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault())
+                .setDoLowerCase(false)
+                .setWithSpecialTokens(true)
+                .setMaxSequenceLength(8)
+                .build()
+        ) {
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car", Tokenization.Truncate.NONE, 5, 0)
+            );
+            assertThat(
+                iae.getMessage(),
+                containsString(
+                    "Unable to do sequence pair tokenization: the combined first sequence and span length [4 + 5 = 9 tokens] "
+                        + "is longer than the max sequence length [8 tokens]. Reduce the size of the [span] window."
+                )
+            );
+        }
+    }
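+
+    // As a cross-check on the numbers asserted above (a hedged breakdown, assuming
+    // TEST_CASED_VOCAB splits "Elasticsearch" into two wordpieces and that a BERT
+    // sequence pair adds three special tokens):
+    //   "Elasticsearch is fun" -> [Elastic, ##search, is, fun]      = 4 tokens
+    //   special tokens for a pair: [CLS] ... [SEP] ... [SEP]        = 3 extra tokens
+    //   First test:  4 + 3 = 7 tokens > maxSequenceLength 3         -> "first sequence [7 tokens]"
+    //   Second test: window left = 8 - 3 - 4 = 1 token, span 5 > 1  -> reported as "[4 + 5 = 9 tokens]"
+    //                against the limit of 8.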
+
     public void testTokenizeLargeInputMultiSequenceTruncation() {
         try (
             BertTokenizer tokenizer = BertTokenizer.builder(