
[ML] fix minor XLM roberta tokenization bug (#96807)

The current XLMRoberta tokenization had two bugs:

 - We were not post-processing tokens appropriately; the Roberta BPE post-processing was being used instead
 - When tokenizing `<mask>` or other never_split tokens, they were not being added to the tokenized results used by our NLP process

This commit fixes both. Since these are bug fixes for an unreleased
feature, this is marked as a non-issue.
Benjamin Trent · 2 years ago · commit 09103b7b16

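For context on both fixes, here is a hedged sketch of the SentencePiece word-boundary convention that XLM-RoBERTa's unigram model relies on. The "▁" (U+2581) marker value and the class name below are illustrative assumptions; only the piece strings come from the test vocabulary in this change. Word starts carry the marker and sub-word continuations do not, so post-processing amounts to handling that prefix rather than reversing Roberta's byte-level BPE.

import java.util.List;

// Illustrative only: demonstrates the SentencePiece "▁" word-boundary marker.
class SentencePiecePrefixSketch {
    static final String PREFIX = "\u2581"; // "▁" (assumed standard SentencePiece marker)

    public static void main(String[] args) {
        // Pieces taken from the test vocabulary below: word starts are prefixed,
        // continuations ("stic", "search") are not.
        List<String> pieces = List.of(PREFIX + "Ela", "stic", "search", PREFIX + "little");
        // Recovering readable text is just a matter of turning the marker back into spaces.
        String text = String.join("", pieces).replace(PREFIX, " ").trim();
        System.out.println(text); // "Elasticsearch little"
    }
}
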
+ 13 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java

@@ -146,13 +146,20 @@ public final class UnigramTokenizer extends Tokenizer {
         offsetAtt.setOffset(correctOffset(whitespaceTokenizer.finalOffset), correctOffset(whitespaceTokenizer.finalOffset));
     }
 
-    @Override
-    public boolean incrementToken() throws IOException {
-        clearAttributes();
+    private void popFromTokens() {
         if (tokens.isEmpty() == false) {
             DelimitedToken.Encoded token = tokens.removeFirst();
+            tokenizedValues.add(token);
             termAtt.setEmpty().append(token.charSequence());
             offsetAtt.setOffset(token.startOffset(), token.endOffset());
+        }
+    }
+
+    @Override
+    public boolean incrementToken() throws IOException {
+        clearAttributes();
+        if (tokens.isEmpty() == false) {
+            popFromTokens();
             return true;
         }
         // First, whitespace tokenize
@@ -160,7 +167,7 @@ public final class UnigramTokenizer extends Tokenizer {
         if (whitespaceToken != null) {
             if (neverSplitHash.contains(whitespaceToken.charSequence())) {
                 Integer maybeTokenized = vocabToId.get(new BytesRef(whitespaceToken.charSequence()));
-                tokenizedValues.add(
+                tokens.add(
                     new DelimitedToken.Encoded(
                         whitespaceToken.charSequence().toString(),
                         Objects.requireNonNullElse(maybeTokenized, unknownTokenId),
@@ -168,7 +175,7 @@ public final class UnigramTokenizer extends Tokenizer {
                         correctOffset(whitespaceToken.endOffset())
                     )
                 );
-                offsetAtt.setOffset(correctOffset(whitespaceToken.startOffset()), correctOffset(whitespaceToken.endOffset()));
+                popFromTokens();
                 return true;
             }
             int inputOffsetStart = whitespaceToken.startOffset();
@@ -217,12 +224,9 @@ public final class UnigramTokenizer extends Tokenizer {
                     MultiCharSequence.from(PREFIX, token.charSequence()),
                     offsetCorrectorFunction
                 );
-                tokenizedValues.addAll(tokenList);
                 tokens.addAll(tokenList);
             }
-            DelimitedToken.Encoded token = tokens.removeFirst();
-            termAtt.setEmpty().append(token.charSequence());
-            offsetAtt.setOffset(token.startOffset(), token.endOffset());
+            popFromTokens();
             return true;
         }
         return false;

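The refactor above funnels every token, never-split or not, through a single queue and records it as it is popped, so nothing can reach the Lucene attributes without also reaching tokenizedValues. A minimal stand-alone sketch of that single-exit-point pattern (class and field names are illustrative, not the plugin's actual types):

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Illustrative only: the "record the token as it is emitted" pattern adopted by popFromTokens().
final class TokenQueueSketch {
    record Encoded(String text, int id) {}

    private final Deque<Encoded> pending = new ArrayDeque<>();       // tokens still to emit
    private final List<Encoded> tokenizedValues = new ArrayList<>(); // what downstream NLP reads

    void enqueue(Encoded token) {
        pending.addLast(token);
    }

    // Every emitted token is recorded exactly once, whether it came from the
    // never-split branch or from the unigram tokenization path.
    Encoded emitNext() {
        Encoded token = pending.removeFirst();
        tokenizedValues.add(token);
        return token;
    }

    public static void main(String[] args) {
        TokenQueueSketch sketch = new TokenQueueSketch();
        sketch.enqueue(new Encoded("<mask>", 250001)); // hypothetical vocabulary id
        sketch.emitNext();
        System.out.println(sketch.tokenizedValues);    // [Encoded[text=<mask>, id=250001]]
    }
}
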
+ 33 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizationResult.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
+
+import java.util.List;
+
+import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.UnigramTokenizer.PREFIX;
+
+public class XLMRobertaTokenizationResult extends RobertaTokenizationResult {
+
+    protected XLMRobertaTokenizationResult(List<String> vocab, List<Tokens> tokenizations, int padTokenId) {
+        super(vocab, tokenizations, padTokenId);
+    }
+
+    @Override
+    public String decode(String token) {
+        if (token.startsWith(PREFIX)) {
+            return token.substring(PREFIX.length());
+        }
+        return token;
+    }
+
+    static class XLMRobertaTokensBuilder extends RobertaTokensBuilder {
+        XLMRobertaTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) {
+            super(withSpecialTokens, clsTokenId, sepTokenId);
+        }
+    }
+}

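A hedged stand-alone equivalent of the decode override above, assuming PREFIX resolves to the standard SentencePiece marker "▁" (U+2581); the constant's actual value lives in UnigramTokenizer and is not shown in this diff. Unlike Roberta's byte-level BPE decoding, nothing needs to be re-mapped, only the word-boundary prefix is stripped:

// Illustrative only: mirrors XLMRobertaTokenizationResult#decode.
class XlmRobertaDecodeSketch {
    static final String PREFIX = "\u2581"; // "▁" (assumed value of UnigramTokenizer.PREFIX)

    static String decode(String token) {
        return token.startsWith(PREFIX) ? token.substring(PREFIX.length()) : token;
    }

    public static void main(String[] args) {
        System.out.println(decode(PREFIX + "car")); // "car"
        System.out.println(decode("stic"));         // "stic"   (continuation pieces carry no prefix)
        System.out.println(decode("<mask>"));       // "<mask>" (special tokens are left untouched)
    }
}
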
+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java

@@ -121,7 +121,7 @@ public class XLMRobertaTokenizer extends NlpTokenizer {
 
     @Override
     public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
-        return new RobertaTokenizationResult(originalVocab, tokenizations, padTokenId);
+        return new XLMRobertaTokenizationResult(originalVocab, tokenizations, padTokenId);
     }
 
     @Override
@@ -160,7 +160,7 @@ public class XLMRobertaTokenizer extends NlpTokenizer {
 
     @Override
     TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
-        return new RobertaTokenizationResult.RobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
+        return new XLMRobertaTokenizationResult.XLMRobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
     }
 
     @Override

+ 17 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizerTests.java

@@ -37,7 +37,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         "▁little",
         "▁red",
         "▁car",
-        "<mask>"
+        "<mask>",
+        "."
     );
     private static final List<Double> TEST_CASE_SCORES = List.of(
         0.0,
@@ -56,7 +57,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         -11.451579093933105,
         -10.858806610107422,
         -10.214239120483398,
-        0.0
+        0.0,
+        -3.0
     );
 
     private List<String> tokenStrings(List<? extends DelimitedToken> tokens) {
@@ -78,6 +80,19 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
         }
     }
 
+    public void testTokenizeWithNeverSplit() throws IOException {
+        try (
+            XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
+                TEST_CASE_VOCAB,
+                TEST_CASE_SCORES,
+                new XLMRobertaTokenization(false, null, Tokenization.Truncate.NONE, -1)
+            ).build()
+        ) {
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch .<mask>.", Tokenization.Truncate.NONE, -1, 0).get(0);
+            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁", ".", "<mask>", "▁", "."));
+        }
+    }
+
     public void testMultiSeqTokenization() throws IOException {
         try (
             XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(

+ 0 - 1
x-pack/plugin/ml/src/test/resources/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json

File diff suppressed because it is too large


Some files were not shown because too many files changed in this diff