Browse Source

[ML] Preserve casing for never split tokens (#81429)

This fixes a bug introduced by #81254. We are now using
a token trie tree to merge tokens belonging to one of the
never-split tokens back together. However, if the tokenizer
is lower-casing, then the merged token will also be lowercase
and won't be matched against never-split tokens that are expected
to be in upper case.

This commit fixes this by looking up the original text and only
merging tokens together when the original text matches one
of the never-split tokens.
Dimitris Athanasiou 3 years ago
parent
commit
86f31c267f

+ 11 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java

@@ -32,6 +32,7 @@ public class BasicTokenizer {
     private final boolean isLowerCase;
     private final boolean isTokenizeCjkChars;
     private final boolean isStripAccents;
+    private final Set<String> neverSplitTokens;
     private final TokenTrieNode neverSplitTokenTrieRoot;
 
     /**
@@ -46,6 +47,7 @@ public class BasicTokenizer {
         this.isLowerCase = isLowerCase;
         this.isTokenizeCjkChars = isTokenizeCjkChars;
         this.isStripAccents = isStripAccents;
+        this.neverSplitTokens = neverSplit;
         this.neverSplitTokenTrieRoot = TokenTrieNode.build(neverSplit, this::doTokenizeString);
     }
 
@@ -76,7 +78,7 @@ public class BasicTokenizer {
      * @return List of tokens
      */
     public List<DelimitedToken> tokenize(String text) {
-        return mergeNeverSplitTokens(doTokenize(text));
+        return mergeNeverSplitTokens(text, doTokenize(text));
     }
 
     private List<String> doTokenizeString(String text) {
@@ -111,7 +113,7 @@ public class BasicTokenizer {
         return processedTokens;
     }
 
-    private List<DelimitedToken> mergeNeverSplitTokens(List<DelimitedToken> tokens) {
+    private List<DelimitedToken> mergeNeverSplitTokens(String originalText, List<DelimitedToken> tokens) {
         if (neverSplitTokenTrieRoot.isLeaf()) {
             return tokens;
         }
@@ -129,7 +131,13 @@ public class BasicTokenizer {
                 mergedTokens.add(token);
             } else if (childNode.isLeaf()) {
                 matchingTokens.add(token);
-                mergedTokens.add(DelimitedToken.mergeTokens(matchingTokens));
+                DelimitedToken mergedToken = DelimitedToken.mergeTokens(matchingTokens);
+                String originalTokenText = originalText.substring(mergedToken.getStartPos(), mergedToken.getEndPos());
+                if (neverSplitTokens.contains(originalTokenText)) {
+                    mergedTokens.add(new DelimitedToken(mergedToken.getStartPos(), mergedToken.getEndPos(), originalTokenText));
+                } else {
+                    mergedTokens.addAll(matchingTokens);
+                }
                 matchingTokens = new ArrayList<>();
                 current = neverSplitTokenTrieRoot;
             } else {

+ 27 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java

@@ -61,7 +61,7 @@ public class BasicTokenizerTests extends ESTestCase {
         assertThat(tokenStrings(tokens), contains("HaLLo", "!", "how", "Are", "yoU", "?"));
     }
 
-    public void testNeverSplit() {
+    public void testNeverSplit_GivenNoLowerCase() {
         BasicTokenizer tokenizer = new BasicTokenizer(false, false, false, Collections.singleton("[UNK]"));
         var tokens = tokenizer.tokenize(" \tHeLLo!how  \n Are yoU? [UNK]");
         assertThat(tokenStrings(tokens), contains("HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"));
@@ -77,8 +77,32 @@ public class BasicTokenizerTests extends ESTestCase {
 
         tokens = tokenizer.tokenize("Hello-[UNK]");
         assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]"));
-        tokens = tokenizer.tokenize("Hello-[UNK][UNK]");
-        assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]", "[UNK]"));
+        tokens = tokenizer.tokenize("Hello~[UNK][UNK]");
+        assertThat(tokenStrings(tokens), contains("Hello", "~", "[UNK]", "[UNK]"));
+        tokens = tokenizer.tokenize("Hello-[unk]");
+        assertThat(tokenStrings(tokens), contains("Hello", "-", "[", "unk", "]"));
+    }
+
+    public void testNeverSplit_GivenLowerCase() {
+        BasicTokenizer tokenizer = new BasicTokenizer(true, false, false, Collections.singleton("[UNK]"));
+        var tokens = tokenizer.tokenize(" \tHeLLo!how  \n Are yoU? [UNK]");
+        assertThat(tokenStrings(tokens), contains("hello", "!", "how", "are", "you", "?", "[UNK]"));
+
+        tokens = tokenizer.tokenize("Hello [UNK].");
+        assertThat(tokenStrings(tokens), contains("hello", "[UNK]", "."));
+
+        tokens = tokenizer.tokenize("Hello [UNK]?");
+        assertThat(tokenStrings(tokens), contains("hello", "[UNK]", "?"));
+
+        tokens = tokenizer.tokenize("Hello [UNK]!!");
+        assertThat(tokenStrings(tokens), contains("hello", "[UNK]", "!", "!"));
+
+        tokens = tokenizer.tokenize("Hello-[UNK]");
+        assertThat(tokenStrings(tokens), contains("hello", "-", "[UNK]"));
+        tokens = tokenizer.tokenize("Hello~[UNK][UNK]");
+        assertThat(tokenStrings(tokens), contains("hello", "~", "[UNK]", "[UNK]"));
+        tokens = tokenizer.tokenize("Hello-[unk]");
+        assertThat(tokenStrings(tokens), contains("hello", "-", "[", "unk", "]"));
     }
 
     public void testSplitOnPunctuation() {