@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
@@ -227,15 +228,17 @@ public abstract class NlpTokenizer implements Releasable {
      * @return tokenization result for the sequence pair
      */
     public List<TokenizationResult.Tokens> tokenize(String seq1, String seq2, Tokenization.Truncate truncate, int span, int sequenceId) {
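+        // Validate up front, before any tokenization work is done: the two sequences must be delimited by special tokens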
+        if (isWithSpecialTokens() == false) {
+            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
+        }
+
         var innerResultSeq1 = innerTokenize(seq1);
         List<? extends DelimitedToken.Encoded> tokenIdsSeq1 = innerResultSeq1.tokens;
         List<Integer> tokenPositionMapSeq1 = innerResultSeq1.tokenPositionMap;
         var innerResultSeq2 = innerTokenize(seq2);
         List<? extends DelimitedToken.Encoded> tokenIdsSeq2 = innerResultSeq2.tokens;
         List<Integer> tokenPositionMapSeq2 = innerResultSeq2.tokenPositionMap;
-        if (isWithSpecialTokens() == false) {
-            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
-        }
         int extraTokens = getNumExtraTokensForSeqPair();
         int numTokens = tokenIdsSeq1.size() + tokenIdsSeq2.size() + extraTokens;
 
@@ -296,6 +299,32 @@ public abstract class NlpTokenizer implements Releasable {
         List<Integer> seq1TokenIds = tokenIdsSeq1.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList());
 
         final int trueMaxSeqLength = maxSequenceLength() - extraTokens - tokenIdsSeq1.size();
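+        // The first sequence is repeated in every split window, so it must leave room for at least one token of seq2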
+        if (trueMaxSeqLength <= 0) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Unable to do sequence pair tokenization: the first sequence [%d tokens] "
+                        + "is longer than the max sequence length [%d tokens]",
+                    tokenIdsSeq1.size() + extraTokens,
+                    maxSequenceLength()
+                )
+            );
+        }
+
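+        // The span (token overlap between consecutive windows) must also fit within the room left for seq2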
+        if (span > trueMaxSeqLength) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Unable to do sequence pair tokenization: the combined first sequence and span length [%d + %d = %d tokens] "
+                        + "is longer than the max sequence length [%d tokens]. Reduce the size of the [span] window.",
+                    tokenIdsSeq1.size(),
+                    span,
+                    tokenIdsSeq1.size() + span,
+                    maxSequenceLength()
+                )
+            );
+        }
+
         while (splitEndPos < tokenIdsSeq2.size()) {
             splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size());
             // Make sure we do not end on a word
|