
[ML] Fix handling surrogate pairs in the XLM Roberta tokenizer (#105183)

UTF-16 represents some characters, notably most emojis, as surrogate pairs: two UTF-16
code units that together encode a single code point. This PR fixes an error in
calculating the number of bytes required to convert a UTF-16 string to UTF-8,
where surrogate pairs were not processed properly.
David Kyle committed 1 year ago
commit 47828788d9
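
For context, this is how a surrogate pair looks from Java, using Lucene's UnicodeUtil (the utility the fix switches to). A standalone sketch for illustration, not part of the commit:

import org.apache.lucene.util.UnicodeUtil;

public class SurrogatePairDemo {
    public static void main(String[] args) {
        String s = "😀"; // U+1F600, stored as the surrogate pair U+D83D U+DE00
        System.out.println(s.length());                       // 2 UTF-16 code units
        System.out.println(s.codePointCount(0, s.length()));  // 1 code point
        // Correct UTF-8 length, measured over the whole pair: 4 bytes
        System.out.println(UnicodeUtil.calcUTF16toUTF8Length(s, 0, s.length()));
    }
}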

+ 7 - 0
docs/changelog/105183.yaml

@@ -0,0 +1,7 @@
+pr: 105183
+summary: Fix handling surrogate pairs in the XLM Roberta tokenizer
+area: Machine Learning
+type: bug
+issues:
+ - 104626
+ - 104981

+ 13 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java

@@ -33,8 +33,6 @@ import java.util.Locale;
 import java.util.Optional;
 import java.util.OptionalInt;
 
-import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils.numUtf8Bytes;
-
 /**
  * This is custom normalizer logic purpose built to replicate the logic in DoubleArray Trie System (darts)
  * object and the sentence piece normalizer.
@@ -179,19 +177,14 @@ public class PrecompiledCharMapNormalizer extends BaseCharFilter {
         b.setText(str);
         // We iterate the whole string, so b.first() is always `0`
         int startIter = b.first();
-        int codePointPos = 0;
         CharsRefBuilder strBuilder = new CharsRefBuilder();
         strBuilder.grow(strBytes.length);
         int bytePos = 0;
         int normalizedCharPos = 0;
         // Keep in mind, these break points aren't necessarily surrogate pairs, but also codepoints that contain a combining mark
         for (int end = b.next(); end != BreakIterator.DONE; startIter = end, end = b.next()) {
-            int byteLen = 0;
-            int numCp = Character.codePointCount(str, startIter, end);
-            for (int i = codePointPos; i < numCp + codePointPos; i++) {
-                byteLen += numUtf8Bytes(strCp[i]);
-            }
-            codePointPos += numCp;
+            int byteLen = UnicodeUtil.calcUTF16toUTF8Length(str, startIter, end - startIter);
+
             // The trie only go up to a depth of 5 bytes.
             // So even looking at it for graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless.
             if (byteLen < 6) {
@@ -209,8 +202,12 @@ public class PrecompiledCharMapNormalizer extends BaseCharFilter {
                 }
             }
             int charByteIndex = 0;
-            for (int i = startIter; i < end; i++) {
-                int utf8CharBytes = numUtf8Bytes(str.charAt(i));
+            int i = startIter;
+            while (i < end) {
+                boolean isSurrogatePair = (i + 1 < end && Character.isSurrogatePair(str.charAt(i), str.charAt(i + 1)));
+                int numUtf16Chars = isSurrogatePair ? 2 : 1;
+
+                int utf8CharBytes = UnicodeUtil.calcUTF16toUTF8Length(str, i, numUtf16Chars);
                 Optional<BytesRef> maybeSubStr = normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes);
                 if (maybeSubStr.isPresent()) {
                     BytesRef subStr = maybeSubStr.get();
@@ -226,8 +223,13 @@ public class PrecompiledCharMapNormalizer extends BaseCharFilter {
                 } else {
                     normalizedCharPos += 1;
                     strBuilder.append(str.charAt(i));
+                    if (isSurrogatePair) {
+                        strBuilder.append(str.charAt(i + 1));
+                    }
                 }
                 charByteIndex += utf8CharBytes;
+
+                i = i + numUtf16Chars;
             }
             bytePos += byteLen;
         }
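
The stepping pattern introduced above (advance two chars when the current position starts a surrogate pair, and measure the UTF-8 length over the whole step) can be sketched in isolation. A hypothetical walk method, shown only to make the iteration explicit:

import org.apache.lucene.util.UnicodeUtil;

public class CodePointWalk {
    // Walk a UTF-16 string one code point at a time, mirroring the loop in
    // PrecompiledCharMapNormalizer: 1 or 2 chars per step, UTF-8 length per step.
    static void walk(CharSequence s) {
        int i = 0;
        while (i < s.length()) {
            boolean isSurrogatePair = i + 1 < s.length()
                && Character.isSurrogatePair(s.charAt(i), s.charAt(i + 1));
            int numUtf16Chars = isSurrogatePair ? 2 : 1;
            int utf8Bytes = UnicodeUtil.calcUTF16toUTF8Length(s, i, numUtf16Chars);
            System.out.printf("chars=%d utf8Bytes=%d%n", numUtf16Chars, utf8Bytes);
            i += numUtf16Chars;
        }
    }

    public static void main(String[] args) {
        walk("a😀b"); // prints: chars=1 utf8Bytes=1, chars=2 utf8Bytes=4, chars=1 utf8Bytes=1
    }
}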

+ 0 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizerUtils.java

@@ -64,19 +64,6 @@ final class TokenizerUtils {
         return bigTokens;
     }
 
-    static int numUtf8Bytes(int c) {
-        if (c < 128) {
-            return 1;
-        }
-        if (c < 2048) {
-            return 2;
-        }
-        if (c < 65536) {
-            return 3;
-        }
-        return 4;
-    }
-
     public record CharSequenceRef(CharSequence wrapped, int offset, int len) implements CharSequence {
 
         public int getOffset() {
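
The removed helper is the root cause of the miscount: each half of a surrogate pair is a code unit in [0xD800, 0xDFFF], which is below 65536, so it was counted as 3 bytes, giving 6 bytes for the pair, while the encoded code point actually takes 4 UTF-8 bytes. Reproducing the arithmetic with the deleted helper (compacted here for illustration):

import java.nio.charset.StandardCharsets;

public class MiscountDemo {
    // The deleted helper, reproduced for illustration.
    static int numUtf8Bytes(int c) {
        if (c < 128) return 1;
        if (c < 2048) return 2;
        if (c < 65536) return 3;
        return 4;
    }

    public static void main(String[] args) {
        String s = "😀"; // U+1F600
        // Old approach: sum per UTF-16 code unit; each surrogate half counts as 3.
        System.out.println(numUtf8Bytes(s.charAt(0)) + numUtf8Bytes(s.charAt(1))); // 6 (wrong)
        System.out.println(s.getBytes(StandardCharsets.UTF_8).length);             // 4 (correct)
    }
}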

+ 8 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java

@@ -30,7 +30,6 @@ import java.util.Objects;
 import java.util.Optional;
 
 import static org.elasticsearch.core.Strings.format;
-import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils.numUtf8Bytes;
 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils.splitOutNeverSplit;
 
 /**
@@ -256,9 +255,14 @@ public final class UnigramTokenizer extends Tokenizer {
         BestPathNode[] bestPathNodes = new BestPathNode[numBytes + 1];
         int bytePos = 0;
         int charPos = 0;
-        while (bytePos < numBytes) {
+        while (charPos < inputSequence.length()) {
             double bestScoreTillHere = bestPathNodes[bytePos] == null ? 0 : bestPathNodes[bytePos].score;
-            int mblen = numUtf8Bytes(inputSequence.charAt(charPos));
+
+            boolean isSurrogatePair = (charPos + 1 < inputSequence.length()
+                && Character.isSurrogatePair(inputSequence.charAt(charPos), inputSequence.charAt(charPos + 1)));
+            int numUtf16Chars = isSurrogatePair ? 2 : 1;
+            int mblen = UnicodeUtil.calcUTF16toUTF8Length(inputSequence, charPos, numUtf16Chars);
+
             boolean hasSingleNode = false;
             // Find the matching prefixes, incrementing by the chars, each time
             for (BytesRef prefix : vocabTrie.matchingPrefixes(new BytesRef(normalizedByteBuffer, bytePos, numBytes - bytePos))) {
@@ -295,7 +299,7 @@ public final class UnigramTokenizer extends Tokenizer {
             }
             // Move our prefix search to the next char
             bytePos += mblen;
-            ++charPos;
+            charPos = charPos + numUtf16Chars;
         }
         int endsAtBytes = numBytes;
         int endsAtChars = inputSequence.length();
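
The same pattern applies here: the loop now advances two cursors in lockstep, charPos in UTF-16 code units and bytePos in UTF-8 bytes, so the loop condition is naturally expressed over chars. A reduced sketch of the two-cursor walk (illustrative, with a hypothetical input):

import org.apache.lucene.util.UnicodeUtil;

public class TwoCursorWalk {
    public static void main(String[] args) {
        CharSequence input = "Elasticsearch 😀";
        int bytePos = 0;
        int charPos = 0;
        while (charPos < input.length()) {
            boolean isSurrogatePair = charPos + 1 < input.length()
                && Character.isSurrogatePair(input.charAt(charPos), input.charAt(charPos + 1));
            int numUtf16Chars = isSurrogatePair ? 2 : 1;
            // Both cursors advance together, one code point per iteration.
            bytePos += UnicodeUtil.calcUTF16toUTF8Length(input, charPos, numUtf16Chars);
            charPos += numUtf16Chars;
        }
        System.out.println(bytePos); // 18: 14 ASCII bytes + 4 bytes for U+1F600
    }
}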

+ 11 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java

@@ -46,6 +46,17 @@ public class PrecompiledCharMapNormalizerTests extends ESTestCase {
         assertNormalization("​​από", parsed, "  από");
     }
 
+    public void testSurrogatePairScenario() throws IOException {
+        PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap();
+        assertNormalization("🇸🇴", parsed, "🇸🇴");
+        assertNormalization("🇸🇴", parsed, "\uD83C\uDDF8\uD83C\uDDF4");
+    }
+
+    public void testEmoji() throws IOException {
+        PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap();
+        assertNormalization("😀", parsed, "😀");
+    }
+
     private void assertNormalization(String input, PrecompiledCharMapNormalizer.Config config, String expected) throws IOException {
         PrecompiledCharMapNormalizer normalizer = new PrecompiledCharMapNormalizer(
             config.offsets(),
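
The 🇸🇴 input in the test above is two regional-indicator code points (U+1F1F8, U+1F1F4), i.e. two surrogate pairs back to back, making it an 8-byte grapheme; under the old per-code-unit count it would have been sized at 12 bytes. Verifying the sizes (illustrative sketch):

import java.nio.charset.StandardCharsets;

public class FlagEmojiDemo {
    public static void main(String[] args) {
        String flag = "\uD83C\uDDF8\uD83C\uDDF4"; // 🇸🇴 = U+1F1F8 U+1F1F4
        System.out.println(flag.length());                                 // 4 UTF-16 code units
        System.out.println(flag.codePointCount(0, flag.length()));         // 2 code points
        System.out.println(flag.getBytes(StandardCharsets.UTF_8).length);  // 8 UTF-8 bytes
    }
}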

+ 43 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizerTests.java

@@ -17,6 +17,8 @@ import java.util.List;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
 
 public class XLMRobertaTokenizerTests extends ESTestCase {
 
@@ -37,6 +39,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         "▁little",
         "▁red",
         "▁car",
+        "▁😀",
+        "▁🇸🇴",
         "<mask>",
         "."
     );
@@ -57,6 +61,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         -11.451579093933105,
         -10.858806610107422,
         -10.214239120483398,
+        -10.230172157287598,
+        -9.451579093933105,
         0.0,
         -3.0
     );
@@ -81,6 +87,43 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         }
     }
 
+    public void testSurrogatePair() throws IOException {
+        try (
+            XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
+                TEST_CASE_VOCAB,
+                TEST_CASE_SCORES,
+                new XLMRobertaTokenization(false, null, Tokenization.Truncate.NONE, -1)
+            ).build()
+        ) {
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("😀", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁\uD83D\uDE00"));
+
+            tokenization = tokenizer.tokenize("Elasticsearch 😀", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00"));
+
+            tokenization = tokenizer.tokenize("Elasticsearch 😀 fun", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00", "▁fun"));
+        }
+    }
+
+    public void testMultiByteEmoji() throws IOException {
+        try (
+            XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
+                TEST_CASE_VOCAB,
+                TEST_CASE_SCORES,
+                new XLMRobertaTokenization(false, null, Tokenization.Truncate.NONE, -1)
+            ).build()
+        ) {
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("🇸🇴", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁🇸🇴"));
+            assertThat(tokenization.tokenIds()[0], not(equalTo(3))); // not the unknown token
+
+            tokenization = tokenizer.tokenize("🏁", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁🏁"));
+            assertThat(tokenization.tokenIds()[0], equalTo(3)); // the unknown token (not in the vocabulary)
+        }
+    }
+
     public void testTokenizeWithNeverSplit() throws IOException {
         try (
             XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(