@@ -47,8 +47,8 @@ public class BertTokenizerTests extends ESTestCase {
         BertTokenizer.PAD_TOKEN
     );

-    private List<String> tokenStrings(List<WordPieceTokenFilter.WordPieceToken> tokens) {
-        return tokens.stream().map(WordPieceTokenFilter.WordPieceToken::toString).collect(Collectors.toList());
+    private List<String> tokenStrings(List<? extends DelimitedToken> tokens) {
+        return tokens.stream().map(DelimitedToken::toString).collect(Collectors.toList());
     }

     public void testTokenize() {
@@ -58,10 +58,10 @@ public class BertTokenizerTests extends ESTestCase {
                 new BertTokenization(null, false, null, Tokenization.Truncate.NONE)
             ).build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", "fun"));
-            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", "fun"));
+            assertArrayEquals(new int[] { 0, 1, 3 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
     }

@@ -103,11 +103,11 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch fun with Pancake and Godzilla",
                 Tokenization.Truncate.FIRST
             );
-            assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 0, 1, 3, 18, 17 }, tokenization.tokenIds());
         }

         try (
@@ -120,16 +120,16 @@ public class BertTokenizerTests extends ESTestCase {
                 "Elasticsearch fun with Pancake and Godzilla",
                 Tokenization.Truncate.FIRST
             );
-            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.tokenMap());
         }
     }

     public void testTokenizeAppendSpecialTokens() {
         try (BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, Tokenization.createDefault()).build()) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 12, 0, 1, 3, 13 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 0, 1, -1 }, tokenization.tokenMap());
         }
     }

@@ -143,13 +143,13 @@ public class BertTokenizerTests extends ESTestCase {
                 .build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch " + specialToken + " fun",
                 Tokenization.Truncate.NONE
             );
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", specialToken, "fun"));
-            assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", specialToken, "fun"));
+            assertArrayEquals(new int[] { 0, 1, 15, 3 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2 }, tokenization.tokenMap());
         }
     }

@@ -161,13 +161,13 @@ public class BertTokenizerTests extends ESTestCase {
             ).setDoLowerCase(false).setWithSpecialTokens(false).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 3, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 3, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 1 }, tokenization.tokenMap());

             tokenization = tokenizer.tokenize("elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }

         try (
@@ -177,9 +177,9 @@ public class BertTokenizerTests extends ESTestCase {
             ).setDoLowerCase(true).setWithSpecialTokens(false).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch fun", Tokenization.Truncate.NONE);
+            assertArrayEquals(new int[] { 0, 1, 2 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());
         }
     }

@@ -189,14 +189,14 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("Elastic", "##search", ",", "fun", "."));
-            assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("Elasticsearch, fun.", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("Elastic", "##search", ",", "fun", "."));
+            assertArrayEquals(new int[] { 0, 1, 11, 3, 10 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3 }, tokenization.tokenMap());

             tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].", Tokenization.Truncate.NONE);
-            assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
+            assertArrayEquals(new int[] { 0, 1, 11, 3, 14, 10 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.tokenMap());
         }
     }

@@ -224,20 +224,20 @@ public class BertTokenizerTests extends ESTestCase {
             ).setWithSpecialTokens(true).setNeverSplit(Set.of("[MASK]")).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize("This is [MASK]-tastic!", Tokenization.Truncate.NONE);
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "[MASK]", "-", "ta", "##stic", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 3, 4, 6, 7, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 4, 5, -1 }, tokenization.tokenMap());

             tokenization = tokenizer.tokenize("This is sub~[MASK]!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", "~", "[MASK]", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", "~", "[MASK]", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 10, 5, 3, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, -1 }, tokenization.tokenMap());

             tokenization = tokenizer.tokenize("This is sub,[MASK].tastic!", Tokenization.Truncate.NONE);
-            assertThat(tokenStrings(tokenization.getTokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!"));
-            assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.getTokenMap());
+            assertThat(tokenStrings(tokenization.tokens()), contains("This", "is", "sub", ",", "[MASK]", ".", "ta", "##stic", "!"));
+            assertArrayEquals(new int[] { 0, 1, 2, 10, 11, 3, 12, 6, 7, 8, 9 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { -1, 0, 1, 2, 3, 4, 5, 6, 6, 7, -1 }, tokenization.tokenMap());
         }
     }

@@ -257,23 +257,23 @@ public class BertTokenizerTests extends ESTestCase {
                     tokenizer.tokenize("Godzilla Pancake red car day", Tokenization.Truncate.NONE)
                 )
             );
-            assertThat(tr.getTokenizations(), hasSize(4));
+            assertThat(tr.getTokens(), hasSize(4));

-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertArrayEquals(new int[] { 0, 1 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0 }, tokenization.getTokenMap());
+            TokenizationResult.Tokens tokenization = tr.getTokenization(0);
+            assertArrayEquals(new int[] { 0, 1 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0 }, tokenization.tokenMap());

-            tokenization = tr.getTokenizations().get(1);
-            assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(1);
+            assertArrayEquals(new int[] { 4, 5, 6, 7 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 1, 2, 3 }, tokenization.tokenMap());

-            tokenization = tr.getTokenizations().get(2);
-            assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(2);
+            assertArrayEquals(new int[] { 8, 9, 16 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1 }, tokenization.tokenMap());

-            tokenization = tr.getTokenizations().get(3);
-            assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.getTokenIds());
-            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.getTokenMap());
+            tokenization = tr.getTokenization(3);
+            assertArrayEquals(new int[] { 8, 9, 17, 6, 7, 16 }, tokenization.tokenIds());
+            assertArrayEquals(new int[] { 0, 0, 1, 2, 3, 4 }, tokenization.tokenMap());
         }
     }

@@ -284,13 +284,13 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(true)
                 .build()
         ) {
-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.NONE
             );

-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
@@ -309,7 +309,7 @@ public class BertTokenizerTests extends ESTestCase {
                     BertTokenizer.SEPARATOR_TOKEN
                 )
             );
-            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.getTokenIds());
+            assertArrayEquals(new int[] { 12, 0, 1, 2, 3, 13, 8, 9, 4, 5, 6, 7, 13 }, tokenization.tokenIds());
         }
     }

@@ -321,13 +321,13 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.FIRST
             );

-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
@@ -359,12 +359,12 @@ public class BertTokenizerTests extends ESTestCase {
             ).build()
         ) {

-            TokenizationResult.Tokenization tokenization = tokenizer.tokenize(
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
                 "Elasticsearch is fun",
                 "Godzilla my little red car",
                 Tokenization.Truncate.SECOND
             );
-            var tokenStream = Arrays.stream(tokenization.getTokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
+            var tokenStream = Arrays.stream(tokenization.tokenIds()).mapToObj(TEST_CASED_VOCAB::get).collect(Collectors.toList());
             assertThat(
                 tokenStream,
                 contains(
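For context on the rename this diff exercises: the getter-style `TokenizationResult.Tokenization` becomes an accessor-style `TokenizationResult.Tokens` (`tokenIds()` instead of `getTokenIds()`, `tokenMap()` instead of `getTokenMap()`), and `tokenStrings()` is widened to accept any `DelimitedToken`. The sketch below is a minimal, hypothetical model of that shape, not the actual Elasticsearch classes — the `record` layout and the offset fields are assumptions for illustration; only the names `Tokens`, `tokens()`, `tokenIds()`, `tokenMap()`, and `DelimitedToken` come from the diff itself.

```java
import java.util.List;

public class TokensSketch {

    // Stand-in for the DelimitedToken type the widened tokenStrings() accepts.
    // The offset fields are assumed; toString() returning the token text matches
    // how the tests compare tokenStrings(...) against plain strings.
    record DelimitedToken(String token, int startOffset, int endOffset) {
        @Override
        public String toString() {
            return token;
        }
    }

    // Accessor-style result holder: tokenIds() replaces getTokenIds(), etc.
    record Tokens(List<DelimitedToken> tokens, int[] tokenIds, int[] tokenMap) {}

    public static void main(String[] args) {
        // Values mirror testTokenize(): "Elasticsearch" word-pieces into ids 0 and 1,
        // "fun" is id 3, and the token map sends both word pieces back to input word 0.
        Tokens tokenization = new Tokens(
            List.of(
                new DelimitedToken("Elastic", 0, 7),
                new DelimitedToken("##search", 7, 13),
                new DelimitedToken("fun", 14, 17)
            ),
            new int[] { 0, 1, 3 },
            new int[] { 0, 0, 1 }
        );
        System.out.println(tokenization.tokens());      // [Elastic, ##search, fun]
        System.out.println(tokenization.tokenIds()[2]); // 3
    }
}
```

The accessor naming (`tokenIds()` rather than `getTokenIds()`) follows Java record conventions, which is consistent with the one-for-one renames throughout the diff.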