|
@@ -13,8 +13,10 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
|
|
|
|
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
|
|
+import java.util.List;
|
|
|
|
|
|
import static org.hamcrest.Matchers.contains;
|
|
import static org.hamcrest.Matchers.contains;
|
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
|
|
|
public class BertTokenizerTests extends ESTestCase {
|
|
public class BertTokenizerTests extends ESTestCase {
|
|
|
|
|
|
@@ -24,7 +26,8 @@ public class BertTokenizerTests extends ESTestCase {
|
|
new BertTokenization(null, false, null)
|
|
new BertTokenization(null, false, null)
|
|
).build();
|
|
).build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
|
|
assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
|
|
@@ -36,7 +39,8 @@ public class BertTokenizerTests extends ESTestCase {
|
|
Tokenization.createDefault()
|
|
Tokenization.createDefault()
|
|
).build();
|
|
).build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("elasticsearch fun");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("elasticsearch fun"));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
|
|
assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
|
|
assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
|
|
@@ -52,7 +56,8 @@ public class BertTokenizerTests extends ESTestCase {
|
|
.setWithSpecialTokens(false)
|
|
.setWithSpecialTokens(false)
|
|
.build();
|
|
.build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch " + specialToken + " fun"));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
|
|
assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
|
|
@@ -67,12 +72,14 @@ public class BertTokenizerTests extends ESTestCase {
|
|
.setWithSpecialTokens(false)
|
|
.setWithSpecialTokens(false)
|
|
.build();
|
|
.build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
|
|
assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
|
|
assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
|
|
|
|
|
|
- tokenization = tokenizer.tokenize("elasticsearch fun");
|
|
|
|
|
|
+ tr = tokenizer.tokenize(List.of("elasticsearch fun"));
|
|
|
|
+ tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
|
|
assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
|
|
}
|
|
}
|
|
|
|
|
|
@@ -82,7 +89,8 @@ public class BertTokenizerTests extends ESTestCase {
|
|
.setWithSpecialTokens(false)
|
|
.setWithSpecialTokens(false)
|
|
.build();
|
|
.build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch fun");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
|
|
assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -93,14 +101,54 @@ public class BertTokenizerTests extends ESTestCase {
|
|
Tokenization.createDefault()
|
|
Tokenization.createDefault()
|
|
).setWithSpecialTokens(false).build();
|
|
).setWithSpecialTokens(false).build();
|
|
|
|
|
|
- TokenizationResult tokenization = tokenizer.tokenize("Elasticsearch, fun.");
|
|
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch, fun."));
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
|
|
assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
|
|
|
|
|
|
- tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].");
|
|
|
|
|
|
+ tr = tokenizer.tokenize(List.of("Elasticsearch, fun [MASK]."));
|
|
|
|
+ tokenization = tr.getTokenizations().get(0);
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", "."));
|
|
assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", "."));
|
|
assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
|
|
assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ public void testBatchInput() {
|
|
|
|
+ BertTokenizer tokenizer = BertTokenizer.builder(
|
|
|
|
+ Arrays.asList("Elastic", "##search", "fun",
|
|
|
|
+ "Pancake", "day",
|
|
|
|
+ "my", "little", "red", "car",
|
|
|
|
+ "God", "##zilla"
|
|
|
|
+ ),
|
|
|
|
+ new BertTokenization(null, false, null)
|
|
|
|
+ ).build();
|
|
|
|
+
|
|
|
|
+ TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch",
|
|
|
|
+ "my little red car",
|
|
|
|
+ "Godzilla day",
|
|
|
|
+ "Godzilla Pancake red car day"
|
|
|
|
+ ));
|
|
|
|
+ assertThat(tr.getTokenizations(), hasSize(4));
|
|
|
|
+
|
|
|
|
+ TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
|
|
|
|
+ assertThat(tokenization.getTokens(), contains("Elastic", "##search"));
|
|
|
|
+ assertArrayEquals(new int[] {0, 1}, tokenization.getTokenIds());
|
|
|
|
+ assertArrayEquals(new int[] {0, 0}, tokenization.getTokenMap());
|
|
|
|
+
|
|
|
|
+ tokenization = tr.getTokenizations().get(1);
|
|
|
|
+ assertThat(tokenization.getTokens(), contains("my", "little", "red", "car"));
|
|
|
|
+ assertArrayEquals(new int[] {5, 6, 7, 8}, tokenization.getTokenIds());
|
|
|
|
+ assertArrayEquals(new int[] {0, 1, 2, 3}, tokenization.getTokenMap());
|
|
|
|
+
|
|
|
|
+ tokenization = tr.getTokenizations().get(2);
|
|
|
|
+ assertThat(tokenization.getTokens(), contains("God", "##zilla", "day"));
|
|
|
|
+ assertArrayEquals(new int[] {9, 10, 4}, tokenization.getTokenIds());
|
|
|
|
+ assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
|
|
|
|
+
|
|
|
|
+ tokenization = tr.getTokenizations().get(3);
|
|
|
|
+ assertThat(tokenization.getTokens(), contains("God", "##zilla", "Pancake", "red", "car", "day"));
|
|
|
|
+ assertArrayEquals(new int[] {9, 10, 3, 7, 8, 4}, tokenization.getTokenIds());
|
|
|
|
+ assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
|
|
|
|
+ }
|
|
}
|
|
}
|