|
@@ -7,6 +7,10 @@
|
|
|
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
|
|
|
|
|
|
import org.elasticsearch.common.util.set.Sets;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
|
|
|
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
+import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
|
|
|
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
@@ -16,6 +20,7 @@ import java.util.List;
|
|
|
import java.util.Set;
|
|
|
import java.util.SortedMap;
|
|
|
import java.util.TreeMap;
|
|
|
+import java.util.function.Function;
|
|
|
|
|
|
/**
|
|
|
* Performs basic tokenization and normalization of input text
|
|
@@ -25,7 +30,7 @@ import java.util.TreeMap;
|
|
|
* Derived from
|
|
|
* https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/src/transformers/tokenization_bert.py
|
|
|
*/
|
|
|
-public class BertTokenizer {
|
|
|
+public class BertTokenizer implements NlpTokenizer {
|
|
|
|
|
|
public static final String UNKNOWN_TOKEN = "[UNK]";
|
|
|
public static final String SEPARATOR_TOKEN = "[SEP]";
|
|
@@ -48,15 +53,18 @@ public class BertTokenizer {
|
|
|
private final boolean doStripAccents;
|
|
|
private final boolean withSpecialTokens;
|
|
|
private final Set<String> neverSplit;
|
|
|
-
|
|
|
- private BertTokenizer(
|
|
|
- List<String> originalVocab,
|
|
|
- SortedMap<String, Integer> vocab,
|
|
|
- boolean doLowerCase,
|
|
|
- boolean doTokenizeCjKChars,
|
|
|
- boolean doStripAccents,
|
|
|
- boolean withSpecialTokens,
|
|
|
- Set<String> neverSplit) {
|
|
|
+ private final int maxSequenceLength;
|
|
|
+ private final NlpTask.RequestBuilder requestBuilder;
|
|
|
+
|
|
|
+ protected BertTokenizer(List<String> originalVocab,
|
|
|
+ SortedMap<String, Integer> vocab,
|
|
|
+ boolean doLowerCase,
|
|
|
+ boolean doTokenizeCjKChars,
|
|
|
+ boolean doStripAccents,
|
|
|
+ boolean withSpecialTokens,
|
|
|
+ int maxSequenceLength,
|
|
|
+ Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
|
|
|
+ Set<String> neverSplit) {
|
|
|
wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
|
|
|
this.originalVocab = originalVocab;
|
|
|
this.vocab = vocab;
|
|
@@ -65,6 +73,8 @@ public class BertTokenizer {
|
|
|
this.doStripAccents = doStripAccents;
|
|
|
this.withSpecialTokens = withSpecialTokens;
|
|
|
this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
|
|
|
+ this.maxSequenceLength = maxSequenceLength;
|
|
|
+ this.requestBuilder = requestBuilderFactory.apply(this);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -76,6 +86,7 @@ public class BertTokenizer {
|
|
|
* @param text Text to tokenize
|
|
|
* @return Tokenized text, token Ids and map
|
|
|
*/
|
|
|
+ @Override
|
|
|
public TokenizationResult tokenize(String text) {
|
|
|
BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
|
|
|
|
|
@@ -126,79 +137,48 @@ public class BertTokenizer {
|
|
|
tokenMap[i] = SPECIAL_TOKEN_POSITION;
|
|
|
}
|
|
|
|
|
|
- return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
|
|
|
- }
|
|
|
-
|
|
|
- public static class TokenizationResult {
|
|
|
-
|
|
|
- String input;
|
|
|
- List<String> vocab;
|
|
|
- private final List<String> tokens;
|
|
|
- private final int [] tokenIds;
|
|
|
- private final int [] tokenMap;
|
|
|
-
|
|
|
- public TokenizationResult(String input, List<String> vocab, List<String> tokens, int[] tokenIds, int[] tokenMap) {
|
|
|
- assert tokens.size() == tokenIds.length;
|
|
|
- assert tokenIds.length == tokenMap.length;
|
|
|
- this.input = input;
|
|
|
- this.vocab = vocab;
|
|
|
- this.tokens = tokens;
|
|
|
- this.tokenIds = tokenIds;
|
|
|
- this.tokenMap = tokenMap;
|
|
|
- }
|
|
|
-
|
|
|
- public String getFromVocab(int tokenId) {
|
|
|
- return vocab.get(tokenId);
|
|
|
+ if (tokenIds.length > maxSequenceLength) {
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
|
|
|
+ tokenIds.length,
|
|
|
+ maxSequenceLength
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * The token strings from the tokenization process
|
|
|
- * @return A list of tokens
|
|
|
- */
|
|
|
- public List<String> getTokens() {
|
|
|
- return tokens;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * The integer values of the tokens in {@link #getTokens()}
|
|
|
- * @return A list of token Ids
|
|
|
- */
|
|
|
- public int[] getTokenIds() {
|
|
|
- return tokenIds;
|
|
|
- }
|
|
|
+ return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
|
|
|
+ }
|
|
|
|
|
|
- /**
|
|
|
- * Maps the token position to the position in the source text.
|
|
|
- * Source words may be divided into more than one token so more
|
|
|
- * than one token can map back to the source token
|
|
|
- * @return Map of source token to
|
|
|
- */
|
|
|
- public int[] getTokenMap() {
|
|
|
- return tokenMap;
|
|
|
- }
|
|
|
+ @Override
|
|
|
+ public NlpTask.RequestBuilder requestBuilder() {
|
|
|
+ return requestBuilder;
|
|
|
+ }
|
|
|
|
|
|
- public String getInput() {
|
|
|
- return input;
|
|
|
- }
|
|
|
+ public int getMaxSequenceLength() {
|
|
|
+ return maxSequenceLength;
|
|
|
}
|
|
|
|
|
|
- public static Builder builder(List<String> vocab) {
|
|
|
- return new Builder(vocab);
|
|
|
+ public static Builder builder(List<String> vocab, Tokenization tokenization) {
|
|
|
+ return new Builder(vocab, tokenization);
|
|
|
}
|
|
|
|
|
|
public static class Builder {
|
|
|
|
|
|
- private final List<String> originalVocab;
|
|
|
- private final SortedMap<String, Integer> vocab;
|
|
|
- private boolean doLowerCase = false;
|
|
|
- private boolean doTokenizeCjKChars = true;
|
|
|
- private boolean withSpecialTokens = true;
|
|
|
- private Boolean doStripAccents = null;
|
|
|
- private Set<String> neverSplit;
|
|
|
-
|
|
|
- private Builder(List<String> vocab) {
|
|
|
+ protected final List<String> originalVocab;
|
|
|
+ protected final SortedMap<String, Integer> vocab;
|
|
|
+ protected boolean doLowerCase = false;
|
|
|
+ protected boolean doTokenizeCjKChars = true;
|
|
|
+ protected boolean withSpecialTokens = true;
|
|
|
+ protected int maxSequenceLength;
|
|
|
+ protected Boolean doStripAccents = null;
|
|
|
+ protected Set<String> neverSplit;
|
|
|
+ protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
|
|
|
+
|
|
|
+ protected Builder(List<String> vocab, Tokenization tokenization) {
|
|
|
this.originalVocab = vocab;
|
|
|
this.vocab = buildSortedVocab(vocab);
|
|
|
+ this.doLowerCase = tokenization.doLowerCase();
|
|
|
+ this.withSpecialTokens = tokenization.withSpecialTokens();
|
|
|
+ this.maxSequenceLength = tokenization.maxSequenceLength();
|
|
|
}
|
|
|
|
|
|
private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
|
|
@@ -229,6 +209,11 @@ public class BertTokenizer {
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
+ public Builder setMaxSequenceLength(int maxSequenceLength) {
|
|
|
+ this.maxSequenceLength = maxSequenceLength;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Include CLS and SEP tokens
|
|
|
* @param withSpecialTokens if true include CLS and SEP tokens
|
|
@@ -239,6 +224,11 @@ public class BertTokenizer {
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
+ public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
|
|
|
+ this.requestBuilderFactory = requestBuilderFactory;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
public BertTokenizer build() {
|
|
|
// if not set strip accents defaults to the value of doLowerCase
|
|
|
if (doStripAccents == null) {
|
|
@@ -249,7 +239,17 @@ public class BertTokenizer {
|
|
|
neverSplit = Collections.emptySet();
|
|
|
}
|
|
|
|
|
|
- return new BertTokenizer(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, neverSplit);
|
|
|
+ return new BertTokenizer(
|
|
|
+ originalVocab,
|
|
|
+ vocab,
|
|
|
+ doLowerCase,
|
|
|
+ doTokenizeCjKChars,
|
|
|
+ doStripAccents,
|
|
|
+ withSpecialTokens,
|
|
|
+ maxSequenceLength,
|
|
|
+ requestBuilderFactory,
|
|
|
+ neverSplit
|
|
|
+ );
|
|
|
}
|
|
|
}
|
|
|
}
|