Browse code

[ML] Fix WordPiece tokenization of unknown words with known subwords (#87510)

Unknown words containing known subwords are still unknown
David Kyle, 3 years ago
Parent
Current commit
09d7e45adf

+ 5 - 0
docs/changelog/87510.yaml

@@ -0,0 +1,5 @@
+pr: 87510
+summary: Fix `WordPiece` tokenization of unknown words with known subwords
+area: Machine Learning
+type: bug
+issues: []

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/WordPieceTokenFilter.java

@@ -150,7 +150,6 @@ public final class WordPieceTokenFilter extends TokenFilter {
                 }
                 int encoding = vocabulary.get(currentValidSubStr);
                 WordPieceToken t = new WordPieceToken(currentValidSubStr, encoding, offsetAtt.startOffset(), offsetAtt.endOffset());
-                tokenizedValues.add(t);
                 tokens.add(t);
                 start = end;
             }
@@ -161,6 +160,7 @@ public final class WordPieceTokenFilter extends TokenFilter {
                 tokenizedValues.add(t);
                 termAtt.setEmpty().append(unknownToken);
             } else {
+                tokenizedValues.addAll(tokens);
                 current = captureState();
                 WordPieceToken token = tokens.removeFirst();
                 termAtt.setEmpty().append(token.charSequence());

+ 15 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -558,4 +558,19 @@ public class BertTokenizerTests extends ESTestCase {
             expectThrows(Exception.class, () -> tokenizer.tokenize("foo", "foo", Tokenization.Truncate.NONE, 0));
         }
     }
+
+    public void testUnknownWordWithKnownSubWords() {
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(
+                TEST_CASED_VOCAB,
+                new BertTokenization(null, false, null, Tokenization.Truncate.NONE, -1)
+            ).build()
+        ) {
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearchfoo fun", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("[UNK]", "fun"));
+            assertEquals(BertTokenizer.UNKNOWN_TOKEN, TEST_CASED_VOCAB.get(tokenization.tokenIds()[0]));
+            assertEquals("fun", TEST_CASED_VOCAB.get(tokenization.tokenIds()[1]));
+            assertArrayEquals(new int[] { 0, 1 }, tokenization.tokenMap());
+        }
+    }
 }