@@ -20,6 +20,8 @@ import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.function.Function;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;

 /**
  * Performs basic tokenization and normalization of input text
@@ -41,7 +43,7 @@ public class BertTokenizer implements NlpTokenizer {

     public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;

-    private final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);
+    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

     private final WordPieceTokenizer wordPieceTokenizer;
     private final List<String> originalVocab;
@@ -50,10 +52,17 @@ public class BertTokenizer implements NlpTokenizer {
     private final boolean doLowerCase;
     private final boolean doTokenizeCjKChars;
     private final boolean doStripAccents;
-    private final boolean withSpecialTokens;
+    protected final boolean withSpecialTokens;
     private final Set<String> neverSplit;
     private final int maxSequenceLength;
     private final NlpTask.RequestBuilder requestBuilder;
+    private final String sepToken;
+    protected final int sepTokenId;
+    private final String clsToken;
+    private final int clsTokenId;
+    private final String padToken;
+    private final String maskToken;
+    private final String unknownToken;

     protected BertTokenizer(
         List<String> originalVocab,
@@ -63,37 +72,97 @@ public class BertTokenizer implements NlpTokenizer {
         boolean doStripAccents,
         boolean withSpecialTokens,
         int maxSequenceLength,
-        Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
         Set<String> neverSplit
     ) {
-        wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
+        this(
+            originalVocab,
+            vocab,
+            doLowerCase,
+            doTokenizeCjKChars,
+            doStripAccents,
+            withSpecialTokens,
+            maxSequenceLength,
+            requestBuilderFactory,
+            Sets.union(neverSplit, NEVER_SPLIT),
+            SEPARATOR_TOKEN,
+            CLASS_TOKEN,
+            PAD_TOKEN,
+            MASK_TOKEN,
+            UNKNOWN_TOKEN
+        );
+    }
+
+    protected BertTokenizer(
+        List<String> originalVocab,
+        SortedMap<String, Integer> vocab,
+        boolean doLowerCase,
+        boolean doTokenizeCjKChars,
+        boolean doStripAccents,
+        boolean withSpecialTokens,
+        int maxSequenceLength,
+        Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+        Set<String> neverSplit,
+        String sepToken,
+        String clsToken,
+        String padToken,
+        String maskToken,
+        String unknownToken
+    ) {
+        wordPieceTokenizer = new WordPieceTokenizer(vocab, unknownToken, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
         this.originalVocab = originalVocab;
         this.vocab = vocab;
         this.doLowerCase = doLowerCase;
         this.doTokenizeCjKChars = doTokenizeCjKChars;
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
-        this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
+        this.neverSplit = neverSplit;
         this.maxSequenceLength = maxSequenceLength;
         this.requestBuilder = requestBuilderFactory.apply(this);
-        if (vocab.containsKey(UNKNOWN_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN);
+        if (vocab.containsKey(unknownToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", unknownToken);
         }
-        if (vocab.containsKey(PAD_TOKEN) == false) {
-            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN);
+        if (vocab.containsKey(padToken) == false) {
+            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", padToken);
         }

         if (withSpecialTokens) {
-            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
+            Set<String> missingSpecialTokens = Sets.difference(Set.of(sepToken, clsToken), vocab.keySet());
             if (missingSpecialTokens.isEmpty() == false) {
                 throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens);
             }
+            this.sepTokenId = vocab.get(sepToken);
+            this.clsTokenId = vocab.get(clsToken);
+        } else {
+            this.sepTokenId = -1;
+            this.clsTokenId = -1;
         }
+        this.sepToken = sepToken;
+        this.clsToken = clsToken;
+        this.padToken = padToken;
+        this.maskToken = maskToken;
+        this.unknownToken = unknownToken;
+    }
+
+    public String getSepToken() {
+        return sepToken;
+    }
+
+    public String getClsToken() {
+        return clsToken;
+    }
+
+    public String getPadToken() {
+        return padToken;
+    }
+
+    public String getUnknownToken() {
+        return unknownToken;
     }

     @Override
     public OptionalInt getPadTokenId() {
-        Integer pad = vocab.get(PAD_TOKEN);
+        Integer pad = vocab.get(this.padToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
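Note: splitting the constructor in two is the seam that makes the special tokens pluggable. BERT behaviour is unchanged (the original constructor delegates with [SEP], [CLS], [PAD], [MASK] and [UNK]), while a subclass can route different strings through the new protected constructor. A minimal sketch of such a subclass, assuming it lives in the same package; the class name and token strings below are illustrative, not part of this change:

    // Hypothetical subclass; token strings are illustrative values only.
    public class AngleBracketTokenizer extends BertTokenizer {
        protected AngleBracketTokenizer(
            List<String> originalVocab,
            SortedMap<String, Integer> vocab,
            boolean doLowerCase,
            boolean doTokenizeCjKChars,
            boolean doStripAccents,
            boolean withSpecialTokens,
            int maxSequenceLength,
            Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
            Set<String> neverSplit
        ) {
            super(
                originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents,
                withSpecialTokens, maxSequenceLength, requestBuilderFactory, neverSplit,
                "</s>",   // sepToken (illustrative)
                "<s>",    // clsToken (illustrative)
                "<pad>",  // padToken (illustrative)
                "<mask>", // maskToken (illustrative)
                "[UNK]"   // unknownToken (illustrative)
            );
        }
    }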
@@ -103,7 +172,7 @@ public class BertTokenizer implements NlpTokenizer {

     @Override
     public OptionalInt getMaskTokenId() {
-        Integer pad = vocab.get(MASK_TOKEN);
+        Integer pad = vocab.get(this.maskToken);
         if (pad != null) {
             return OptionalInt.of(pad);
         } else {
@@ -113,7 +182,7 @@ public class BertTokenizer implements NlpTokenizer {

     @Override
     public String getMaskToken() {
-        return MASK_TOKEN;
+        return maskToken;
     }

     @Override
@@ -150,6 +219,7 @@ public class BertTokenizer implements NlpTokenizer {
                 case SECOND:
                     isTruncated = true;
                     wordPieceTokenIds = wordPieceTokenIds.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
+                    tokenPositionMap = tokenPositionMap.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -158,31 +228,16 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        if (withSpecialTokens) {
-            tokenIds[0] = vocab.get(CLASS_TOKEN);
-            tokenMap[0] = SPECIAL_TOKEN_POSITION;
-        }
-
-        int i = withSpecialTokens ? 1 : 0;
-        final int decrementHandler = withSpecialTokens ? 1 : 0;
-        for (var tokenId : wordPieceTokenIds) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMap.get(i - decrementHandler);
-            i++;
-        }
-
-        if (withSpecialTokens) {
-            tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-            tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
-
-        return new TokenizationResult.Tokenization(seq, innerResult.tokens, isTruncated, tokenIds, tokenMap);
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIds, tokenPositionMap)
+            .addEndTokensIfNecessary();
+        return new TokenizationResult.Tokenization(
+            seq,
+            innerResult.tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
     }

     @Override
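Note: the hand-rolled array filling deleted above is replaced by the BertTokenizationBuilder introduced later in this diff; the output is identical. A worked example of the single-sequence layout with special tokens enabled (the word-piece ids are illustrative; SPECIAL_TOKEN_POSITION marks map entries that have no source token):

    // word pieces "el", "##astic" from source token 0 -> illustrative ids 1293, 2061
    // buildIds() -> [clsTokenId, 1293, 2061, sepTokenId]
    // buildMap() -> [SPECIAL_TOKEN_POSITION, 0, 0, SPECIAL_TOKEN_POSITION]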
@@ -196,39 +251,47 @@ public class BertTokenizer implements NlpTokenizer {
         if (withSpecialTokens == false) {
             throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
         }
-        // [CLS] seq1 [SEP] seq2 [SEP]
-        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + 3;
+        int extraTokens = getNumExtraTokensForSeqPair();
+        int numTokens = wordPieceTokenIdsSeq1.size() + wordPieceTokenIdsSeq2.size() + extraTokens;

         boolean isTruncated = false;
         if (numTokens > maxSequenceLength) {
             switch (truncate) {
                 case FIRST:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq2.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the second sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq2.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq2.size());
+                    wordPieceTokenIdsSeq1 = wordPieceTokenIdsSeq1.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size()
+                    );
+                    tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq2.size());
                     break;
                 case SECOND:
                     isTruncated = true;
-                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - 3) {
+                    if (wordPieceTokenIdsSeq1.size() > maxSequenceLength - extraTokens) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the first sequence. "
                                 + "The tokenized input length [{}] exceeds the maximum sequence length [{}], "
                                 + "when taking special tokens into account",
                             truncate.toString(),
                             wordPieceTokenIdsSeq1.size(),
-                            maxSequenceLength - 3
+                            maxSequenceLength - extraTokens
                         );
                     }
-                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(0, maxSequenceLength - 3 - wordPieceTokenIdsSeq1.size());
+                    wordPieceTokenIdsSeq2 = wordPieceTokenIdsSeq2.subList(
+                        0,
+                        maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size()
+                    );
+                    tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength - extraTokens - wordPieceTokenIdsSeq1.size());
                     break;
                 case NONE:
                     throw ExceptionsHelper.badRequestException(
@@ -237,38 +300,27 @@ public class BertTokenizer implements NlpTokenizer {
                         maxSequenceLength
                     );
             }
-            numTokens = maxSequenceLength;
-        }
-        int[] tokenIds = new int[numTokens];
-        int[] tokenMap = new int[numTokens];
-
-        tokenIds[0] = vocab.get(CLASS_TOKEN);
-        tokenMap[0] = SPECIAL_TOKEN_POSITION;
-
-        int i = 1;
-        for (var tokenId : wordPieceTokenIdsSeq1) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq1.get(i - 1);
-            i++;
         }
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-        ++i;
-
-        int j = 0;
-        for (var tokenId : wordPieceTokenIdsSeq2) {
-            tokenIds[i] = tokenId;
-            tokenMap[i] = tokenPositionMapSeq2.get(j);
-            i++;
-            j++;
-        }
-
-        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
-        tokenMap[i] = SPECIAL_TOKEN_POSITION;
-
+        BertTokenizationBuilder bertTokenizationBuilder = bertTokenizationBuilder().addTokens(wordPieceTokenIdsSeq1, tokenPositionMapSeq1)
+            .addTokens(wordPieceTokenIdsSeq2, tokenPositionMapSeq2)
+            .addEndTokensIfNecessary();
         List<DelimitedToken> tokens = new ArrayList<>(innerResultSeq1.tokens);
         tokens.addAll(innerResultSeq2.tokens);
-        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, isTruncated, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(
+            seq1 + seq2,
+            tokens,
+            isTruncated,
+            bertTokenizationBuilder.buildIds(),
+            bertTokenizationBuilder.buildMap()
+        );
+    }
+
+    protected BertTokenizationBuilder bertTokenizationBuilder() {
+        return new BertTokenizationBuilder();
+    }
+
+    protected int getNumExtraTokensForSeqPair() {
+        return 3;
     }

     private InnerTokenization innerTokenize(String seq) {
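Note: getNumExtraTokensForSeqPair() and bertTokenizationBuilder() are the two override points for models that frame sequence pairs differently. A hedged sketch, not part of this change, of a hypothetical subclass (same package assumed, so the package-private addTokens can be overridden) that emits a doubled separator between the two sequences, so a pair costs four special tokens instead of three:

    // Hypothetical subclass for a "cls seq1 sep sep seq2 sep" pair format.
    public class DoubleSepTokenizer extends BertTokenizer {
        // ... constructor delegating to the protected BertTokenizer constructor ...

        @Override
        protected int getNumExtraTokensForSeqPair() {
            return 4; // cls + sep + sep + sep
        }

        @Override
        protected BertTokenizationBuilder bertTokenizationBuilder() {
            return new BertTokenizationBuilder() {
                @Override
                BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
                    // Emit one extra separator before every sequence after the
                    // first; super.addTokens(...) then emits the usual one.
                    if (numSeq > 0 && withSpecialTokens) {
                        tokenIds.add(IntStream.of(sepTokenId));
                        tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
                    }
                    return super.addTokens(wordPieceTokenIds, tokenPositionMap);
                }
            };
        }
    }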
@@ -280,7 +332,7 @@ public class BertTokenizer implements NlpTokenizer {
         for (int sourceIndex = 0; sourceIndex < tokenSequences.size(); sourceIndex++) {
             String token = tokenSequences.get(sourceIndex).getToken();
             if (neverSplit.contains(token)) {
-                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN)));
+                wordPieceTokens.add(vocab.getOrDefault(token, vocab.get(unknownToken)));
                 tokenPositionMap.add(sourceIndex);
             } else {
                 List<Integer> tokens = wordPieceTokenizer.tokenize(tokenSequences.get(sourceIndex));
@@ -319,6 +371,48 @@ public class BertTokenizer implements NlpTokenizer {
         return new Builder(vocab, tokenization);
     }

+    protected class BertTokenizationBuilder {
+        Stream.Builder<IntStream> tokenIds;
+        Stream.Builder<IntStream> tokenMap;
+        int numSeq;
+
+        BertTokenizationBuilder() {
+            tokenIds = Stream.builder();
+            tokenMap = Stream.builder();
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(clsTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+        }
+
+        BertTokenizationBuilder addTokens(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
+            if (numSeq > 0 && withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            tokenIds.add(wordPieceTokenIds.stream().mapToInt(Integer::valueOf));
+            tokenMap.add(tokenPositionMap.stream().mapToInt(Integer::valueOf));
+            numSeq++;
+            return this;
+        }
+
+        BertTokenizationBuilder addEndTokensIfNecessary() {
+            if (withSpecialTokens) {
+                tokenIds.add(IntStream.of(sepTokenId));
+                tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
+            }
+            return this;
+        }
+
+        int[] buildIds() {
+            return tokenIds.build().flatMapToInt(Function.identity()).toArray();
+        }
+
+        int[] buildMap() {
+            return tokenMap.build().flatMapToInt(Function.identity()).toArray();
+        }
+    }
+
     public static class Builder {

         protected final List<String> originalVocab;
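Note: BertTokenizationBuilder leans on Stream.Builder<IntStream> so that each segment (leading class token, word pieces per sequence, separators) is appended lazily and flattened into a single int[] only once, in buildIds()/buildMap(). A self-contained, runnable demonstration of the same flattening pattern (the ids 101 and 102 are only illustrative stand-ins for cls/sep ids):

    import java.util.Arrays;
    import java.util.function.Function;
    import java.util.stream.IntStream;
    import java.util.stream.Stream;

    public class FlattenDemo {
        public static void main(String[] args) {
            Stream.Builder<IntStream> segments = Stream.builder();
            segments.add(IntStream.of(101));        // leading special token
            segments.add(IntStream.of(7592, 2088)); // word-piece ids of a sequence
            segments.add(IntStream.of(102));        // trailing special token
            int[] flat = segments.build().flatMapToInt(Function.identity()).toArray();
            System.out.println(Arrays.toString(flat)); // [101, 7592, 2088, 102]
        }
    }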
@@ -329,7 +423,7 @@ public class BertTokenizer implements NlpTokenizer {
         protected int maxSequenceLength;
         protected Boolean doStripAccents = null;
         protected Set<String> neverSplit;
-        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
+        protected Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;

         protected Builder(List<String> vocab, Tokenization tokenization) {
             this.originalVocab = vocab;
@@ -382,7 +476,7 @@ public class BertTokenizer implements NlpTokenizer {
             return this;
         }

-        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+        public Builder setRequestBuilderFactory(Function<NlpTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
            this.requestBuilderFactory = requestBuilderFactory;
            return this;
        }
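Note: widening the factory type from Function<BertTokenizer, ...> to Function<NlpTokenizer, ...> lets one request-builder factory serve any NlpTokenizer rather than BERT alone; BertRequestBuilder::new still satisfies the new signature, as the unchanged default above shows. A hedged usage sketch (the builder(...) and build() names are assumed from the surrounding class, which this diff only shows in part):

    Function<NlpTokenizer, NlpTask.RequestBuilder> factory = BertRequestBuilder::new;
    BertTokenizer tokenizer = BertTokenizer.builder(vocabList, tokenizationConfig)
        .setRequestBuilderFactory(factory)
        .build();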