
[ML] fix tokenization bug when handling normalization in BERT and MPNet (#92329)

There is a small bug in how we handle normalization and accent stripping in BERT and MPNet. 

When triggered, the bug surfaces as an error like this:
```
'java.lang.IllegalArgumentException: startOffset must be non-negative, and endOffset must be >= startOffset; got startOffset=553,endOffset=552', 'caused_by': {'type': 'illegal_argument_exception', 'reason': 'startOffset must be non-negative, and endOffset must be >= startOffset; got startOffset=553,endOffset=552'
```

The bug occurs when we normalize a string and then attempt to split it on a punctuation mark. Since we don't adjust offsets to account for normalization, splitting on punctuation after normalizing can produce invalid offsets. Take the token `Br창n's` as an example: the old code normalized the string first (which decomposes `창` and thus increases the character count) and then split on `'`. Because the decomposition changed the character count, the offsets we calculate from character positions no longer matched the original input.
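To make the offset drift concrete, here is a minimal, self-contained sketch (my own illustration, not code from this PR) using the JDK's `java.text.Normalizer`; the tokenizer itself uses ICU's normalizer, but the decomposition is the same. Under NFD, the Hangul syllable `창` decomposes into three jamo, so the normalized string is longer than the input it came from:

```java
import java.text.Normalizer;

public class DecompositionDemo {
    public static void main(String[] args) {
        String token = "Br창n's"; // 6 chars: the syllable 창 is a single code unit
        // NFD splits 창 (U+CC3D) into the jamo ᄎ + ᅡ + ᆼ (three code units)
        String nfd = Normalizer.normalize(token, Normalizer.Form.NFD);
        System.out.println(token.length()); // 6
        System.out.println(nfd.length());   // 8
        // Any offset computed against the normalized text is now shifted by 2
        // relative to the original input.
    }
}
```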

To simplify the logic, I changed normalization to be done AFTER our split on punctuation. This way the offsets still refer to the original input, and it does not matter whether normalization adds or removes characters, because the punctuation split has already been handled.
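As a rough sketch of the new ordering (again my own illustration: a hypothetical `Tok` record stands in for the real `DelimitedToken`, and only the apostrophe is treated as punctuation), splitting first pins the offsets to the original input, so normalizing each piece afterwards cannot disturb them:

```java
import java.text.Normalizer;
import java.util.ArrayList;
import java.util.List;

public class SplitThenNormalize {
    // Hypothetical stand-in for DelimitedToken: text plus offsets into the ORIGINAL input.
    record Tok(String text, int start, int end) {}

    public static void main(String[] args) {
        String input = "Br창n's";
        List<Tok> toks = new ArrayList<>();
        int start = 0;
        // 1) Split on punctuation against the original input, so offsets are exact.
        for (int i = 0; i < input.length(); i++) {
            if (input.charAt(i) == '\'') { // simplified: only splits on the apostrophe
                toks.add(new Tok(input.substring(start, i), start, i));
                toks.add(new Tok("'", i, i + 1));
                start = i + 1;
            }
        }
        if (start < input.length()) {
            toks.add(new Tok(input.substring(start), start, input.length()));
        }
        // 2) Normalize AFTER splitting; start/end still point into the original
        //    input no matter how many characters decomposition adds or removes.
        for (Tok t : toks) {
            String norm = Normalizer.normalize(t.text(), Normalizer.Form.NFD);
            System.out.println(norm + " -> [" + t.start() + "," + t.end() + ")");
        }
        // Prints: Br창n -> [0,4)   ' -> [4,5)   s -> [5,6)
    }
}
```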

closes: https://github.com/elastic/elasticsearch/issues/92243
Benjamin Trent · commit 143fe5b1c7

+ 5 - 0
docs/changelog/92329.yaml

@@ -0,0 +1,5 @@
+pr: 92329
+summary: Fix tokenization bug when handling normalization in BERT and MPNet
+area: Machine Learning
+type: bug
+issues: []

+ 16 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilter.java

@@ -118,19 +118,19 @@ public final class BasicTokenFilter extends TokenFilter {
         }
         current = null; // not really needed, but for safety
         if (input.incrementToken()) {
-            if (isStripAccents) {
-                stripAccent();
-            }
             if (neverSplitSet.contains(termAtt)) {
                 return true;
             }
             // split punctuation and maybe cjk chars!!!
             LinkedList<DelimitedToken> splits = split();
-            // There is nothing to merge, nothing to store, simply return
-            if (splits.size() == 1) {
-                return true;
+            LinkedList<DelimitedToken> delimitedTokens = mergeSplits(splits);
+            if (isStripAccents) {
+                for (DelimitedToken token : delimitedTokens) {
+                    tokens.add(stripAccent(token));
+                }
+            } else {
+                tokens.addAll(delimitedTokens);
             }
-            tokens.addAll(mergeSplits(splits));
             this.current = captureState();
             DelimitedToken token = tokens.removeFirst();
             termAtt.setEmpty().append(token.charSequence());
@@ -140,14 +140,14 @@ public final class BasicTokenFilter extends TokenFilter {
         return false;
     }
 
-    private void stripAccent() {
+    private DelimitedToken stripAccent(DelimitedToken token) {
         accentBuffer.setLength(0);
         boolean changed = false;
-        if (normalizer.quickCheck(termAtt) != Normalizer.YES) {
-            normalizer.normalize(termAtt, accentBuffer);
+        if (normalizer.quickCheck(token.charSequence()) != Normalizer.YES) {
+            normalizer.normalize(token.charSequence(), accentBuffer);
             changed = true;
         } else {
-            accentBuffer.append(termAtt);
+            accentBuffer.append(token.charSequence());
         }
         List<Integer> badIndices = new ArrayList<>();
         List<Integer> charCount = new ArrayList<>();
@@ -172,8 +172,9 @@ public final class BasicTokenFilter extends TokenFilter {
             }
         }
         if (changed) {
-            termAtt.setEmpty().append(accentBuffer);
+            return new DelimitedToken(accentBuffer.toString(), token.startOffset(), token.endOffset());
         }
+        return token;
     }
 
     private LinkedList<DelimitedToken> split() {
@@ -210,6 +211,9 @@ public final class BasicTokenFilter extends TokenFilter {
     }
 
     private LinkedList<DelimitedToken> mergeSplits(LinkedList<DelimitedToken> splits) {
+        if (splits.size() == 1) {
+            return splits;
+        }
         LinkedList<DelimitedToken> mergedTokens = new LinkedList<>();
         List<DelimitedToken> matchingTokens = new ArrayList<>();
         CharSeqTokenTrieNode current = neverSplit;

+ 29 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -67,6 +67,35 @@ public class BertTokenizerTests extends ESTestCase {
         }
     }
 
+    public void testTokenizeFailureCaseAccentFilter() {
+        List<String> testingVocab = List.of(
+            "[CLS]",
+            "br",
+            "##ᄎ",
+            "##ᅡ",
+            "##ᆼ",
+            "##n",
+            "'",
+            "s",
+            "[SEP]",
+            BertTokenizer.MASK_TOKEN,
+            BertTokenizer.UNKNOWN_TOKEN,
+            BertTokenizer.PAD_TOKEN
+        );
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(
+                testingVocab,
+                new BertTokenization(true, true, 512, Tokenization.Truncate.FIRST, -1)
+            ).build()
+        ) {
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Br창n's", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenization.tokenIds(), equalTo(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8 }));
+
+            tokenization = tokenizer.tokenize("Br창n", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenization.tokenIds(), equalTo(new int[] { 0, 1, 2, 3, 4, 5, 8 }));
+        }
+    }
+
     public void testTokenizeLargeInputNoTruncation() {
         try (
             BertTokenizer tokenizer = BertTokenizer.builder(