Przeglądaj źródła

[ML] Handle multiple punctuation chars adjacent to never-split token (#78607)

This commit fixes a problem when tokenizing text that contains
tokens that should never be split followed by multiple punctuation
characters.
Dimitris Athanasiou 4 lat temu
rodzic
commit
45a4149401

+ 19 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java

@@ -101,12 +101,14 @@ public class BasicTokenizer {
 
 
             // At this point text has been tokenized by whitespace
             // At this point text has been tokenized by whitespace
             // but one of the special never split tokens could be adjacent
             // but one of the special never split tokens could be adjacent
-            // to a punctuation character.
-            if (isCommonPunctuation(token.codePointAt(token.length() -1)) &&
-                    neverSplit.contains(token.substring(0, token.length() -1))) {
-                processedTokens.add(token.substring(0, token.length() -1));
-                processedTokens.add(token.substring(token.length() -1));
-                continue;
+            // to one or more punctuation characters.
+            if (isCommonPunctuation(token.codePointAt(token.length() -1))) {
+                int lastNonPunctuationIndex = findLastNonPunctuationIndex(token);
+                if (lastNonPunctuationIndex >= 0 && neverSplit.contains(token.substring(0, lastNonPunctuationIndex + 1))) {
+                    processedTokens.add(token.substring(0, lastNonPunctuationIndex + 1));
+                    processedTokens.addAll(splitOnPunctuation(token.substring(lastNonPunctuationIndex + 1)));
+                    continue;
+                }
             }
             }
 
 
             if (isLowerCase) {
             if (isLowerCase) {
@@ -121,6 +123,17 @@ public class BasicTokenizer {
         return processedTokens;
         return processedTokens;
     }
     }
 
 
+    private int findLastNonPunctuationIndex(String token) {
+        int i = token.length() - 1;
+        while (i >= 0) {
+            if (isCommonPunctuation(token.codePointAt(i)) == false) {
+                break;
+            }
+            i--;
+        }
+        return i;
+    }
+
     public boolean isLowerCase() {
     public boolean isLowerCase() {
         return isLowerCase;
         return isLowerCase;
     }
     }

+ 6 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java

@@ -69,6 +69,9 @@ public class BasicTokenizerTests extends ESTestCase {
 
 
         tokens = tokenizer.tokenize("Hello [UNK]?");
         tokens = tokenizer.tokenize("Hello [UNK]?");
         assertThat(tokens, contains("Hello", "[UNK]", "?"));
         assertThat(tokens, contains("Hello", "[UNK]", "?"));
+
+        tokens = tokenizer.tokenize("Hello [UNK]!!");
+        assertThat(tokens, contains("Hello", "[UNK]", "!", "!"));
     }
     }
 
 
     public void testSplitOnPunctuation() {
     public void testSplitOnPunctuation() {
@@ -92,6 +95,9 @@ public class BasicTokenizerTests extends ESTestCase {
 
 
         tokens = BasicTokenizer.splitOnPunctuation("hi.");
         tokens = BasicTokenizer.splitOnPunctuation("hi.");
         assertThat(tokens, contains("hi", "."));
         assertThat(tokens, contains("hi", "."));
+
+        tokens = BasicTokenizer.splitOnPunctuation("!!");
+        assertThat(tokens, contains("!", "!"));
     }
     }
 
 
     public void testStripAccents() {
     public void testStripAccents() {