|
@@ -17,6 +17,8 @@ import java.util.List;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.hamcrest.Matchers.contains;
|
|
|
+import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.not;
|
|
|
|
|
|
public class XLMRobertaTokenizerTests extends ESTestCase {
|
|
|
|
|
@@ -37,6 +39,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
|
|
|
"▁little",
|
|
|
"▁red",
|
|
|
"▁car",
|
|
|
+ "▁😀",
|
|
|
+ "▁🇸🇴",
|
|
|
"<mask>",
|
|
|
"."
|
|
|
);
|
|
@@ -57,6 +61,8 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
|
|
|
-11.451579093933105,
|
|
|
-10.858806610107422,
|
|
|
-10.214239120483398,
|
|
|
+ -10.230172157287598,
|
|
|
+ -9.451579093933105,
|
|
|
0.0,
|
|
|
-3.0
|
|
|
);
|
|
@@ -81,6 +87,43 @@ public class XLMRobertaTokenizerTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testSurrogatePair() throws IOException { // Checks that an emoji written as a UTF-16 surrogate pair comes back as one token, alone and next to other words.
|
|
|
+ try (
|
|
|
+ XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
|
|
|
+ TEST_CASE_VOCAB,
|
|
|
+ TEST_CASE_SCORES,
|
|
|
+ new XLMRobertaTokenization(false, null, Tokenization.Truncate.NONE, -1) // NOTE(review): meaning of the (false, null, NONE, -1) arguments is not visible in this hunk; confirm against XLMRobertaTokenization
|
|
|
+ ).build()
|
|
|
+ ) {
|
|
|
+ TokenizationResult.Tokens tokenization = tokenizer.tokenize("😀", Tokenization.Truncate.NONE, -1, 0).get(0); // case 1: the emoji is the whole input
|
|
|
+ assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁\uD83D\uDE00")); // \uD83D\uDE00 is the surrogate-pair spelling of 😀 (U+1F600); presumably matches the "▁😀" vocabulary entry, verify against TEST_CASE_VOCAB
|
|
|
+
|
|
|
+ tokenization = tokenizer.tokenize("Elasticsearch 😀", Tokenization.Truncate.NONE, -1, 0).get(0); // case 2: the emoji ends the input, after a word
|
|
|
+ assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00"));
|
|
|
+
|
|
|
+ tokenization = tokenizer.tokenize("Elasticsearch 😀 fun", Tokenization.Truncate.NONE, -1, 0).get(0); // case 3: the emoji sits between two words
|
|
|
+ assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00", "▁fun"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testMultiByteEmoji() throws IOException { // Checks two emoji inputs: a flag expected to map to a known token id, and an emoji expected to map to the unknown token id.
|
|
|
+ try (
|
|
|
+ XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
|
|
|
+ TEST_CASE_VOCAB,
|
|
|
+ TEST_CASE_SCORES,
|
|
|
+ new XLMRobertaTokenization(false, null, Tokenization.Truncate.NONE, -1) // NOTE(review): meaning of the (false, null, NONE, -1) arguments is not visible in this hunk; confirm against XLMRobertaTokenization
|
|
|
+ ).build()
|
|
|
+ ) {
|
|
|
+ TokenizationResult.Tokens tokenization = tokenizer.tokenize("🇸🇴", Tokenization.Truncate.NONE, -1, 0).get(0); // 🇸🇴 is a flag built from two regional-indicator code points (U+1F1F8 U+1F1F4), i.e. four UTF-16 code units
|
|
|
+ assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁🇸🇴")); // the whole flag must come back as a single token string
|
|
|
+ assertThat(tokenization.tokenIds()[0], not(equalTo(3))); // not the unknown token
|
|
|
+
|
|
|
+ tokenization = tokenizer.tokenize("🏁", Tokenization.Truncate.NONE, -1, 0).get(0); // 🏁 (U+1F3C1) is one code point outside the BMP, i.e. one surrogate pair
|
|
|
+ assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁🏁")); // the token text is still asserted to be the original emoji, even though its id below is the unknown one
|
|
|
+ assertThat(tokenization.tokenIds()[0], equalTo(3)); // the unknown token (not in the vocabulary)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testTokenizeWithNeverSplit() throws IOException {
|
|
|
try (
|
|
|
XLMRobertaTokenizer tokenizer = XLMRobertaTokenizer.builder(
|