@@ -7,6 +7,10 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
+import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -16,6 +20,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.function.Function;
 
 /**
  * Performs basic tokenization and normalization of input text
@@ -25,7 +30,7 @@ import java.util.TreeMap;
  * Derived from
  * https://github.com/huggingface/transformers/blob/ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532/src/transformers/tokenization_bert.py
  */
-public class BertTokenizer {
+public class BertTokenizer implements NlpTokenizer {
 
     public static final String UNKNOWN_TOKEN = "[UNK]";
     public static final String SEPARATOR_TOKEN = "[SEP]";
@@ -48,15 +53,18 @@ public class BertTokenizer {
     private final boolean doStripAccents;
     private final boolean withSpecialTokens;
     private final Set<String> neverSplit;
-
-    private BertTokenizer(
-                          List<String> originalVocab,
-                          SortedMap<String, Integer> vocab,
-                          boolean doLowerCase,
-                          boolean doTokenizeCjKChars,
-                          boolean doStripAccents,
-                          boolean withSpecialTokens,
-                          Set<String> neverSplit) {
+    private final int maxSequenceLength;
+    private final NlpTask.RequestBuilder requestBuilder;
+
+    protected BertTokenizer(List<String> originalVocab,
+                            SortedMap<String, Integer> vocab,
+                            boolean doLowerCase,
+                            boolean doTokenizeCjKChars,
+                            boolean doStripAccents,
+                            boolean withSpecialTokens,
+                            int maxSequenceLength,
+                            Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory,
+                            Set<String> neverSplit) {
         wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
         this.originalVocab = originalVocab;
         this.vocab = vocab;
@@ -65,6 +73,8 @@ public class BertTokenizer {
         this.doStripAccents = doStripAccents;
         this.withSpecialTokens = withSpecialTokens;
         this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
+        this.maxSequenceLength = maxSequenceLength;
+        this.requestBuilder = requestBuilderFactory.apply(this);
     }
 
     /**
@@ -76,6 +86,7 @@ public class BertTokenizer {
      * @param text Text to tokenize
      * @return Tokenized text, token Ids and map
      */
+    @Override
     public TokenizationResult tokenize(String text) {
         BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
 
@@ -126,79 +137,48 @@ public class BertTokenizer {
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
 
-        return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
-    }
-
-    public static class TokenizationResult {
-
-        String input;
-        List<String> vocab;
-        private final List<String> tokens;
-        private final int [] tokenIds;
-        private final int [] tokenMap;
-
-        public TokenizationResult(String input, List<String> vocab, List<String> tokens, int[] tokenIds, int[] tokenMap) {
-            assert tokens.size() == tokenIds.length;
-            assert tokenIds.length == tokenMap.length;
-            this.input = input;
-            this.vocab = vocab;
-            this.tokens = tokens;
-            this.tokenIds = tokenIds;
-            this.tokenMap = tokenMap;
-        }
-
-        public String getFromVocab(int tokenId) {
-            return vocab.get(tokenId);
+        if (tokenIds.length > maxSequenceLength) {
+            throw ExceptionsHelper.badRequestException(
+                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
+                tokenIds.length,
+                maxSequenceLength
+            );
         }
 
-        /**
-         * The token strings from the tokenization process
-         * @return A list of tokens
-         */
-        public List<String> getTokens() {
-            return tokens;
-        }
-
-        /**
-         * The integer values of the tokens in {@link #getTokens()}
-         * @return A list of token Ids
-         */
-        public int[] getTokenIds() {
-            return tokenIds;
-        }
+        return new TokenizationResult(text, originalVocab, tokens, tokenIds, tokenMap);
+    }
 
-        /**
-         * Maps the token position to the position in the source text.
-         * Source words may be divided into more than one token so more
-         * than one token can map back to the source token
-         * @return Map of source token to
-         */
-        public int[] getTokenMap() {
-            return tokenMap;
-        }
+    @Override
+    public NlpTask.RequestBuilder requestBuilder() {
+        return requestBuilder;
+    }
 
-        public String getInput() {
-            return input;
-        }
+    public int getMaxSequenceLength() {
+        return maxSequenceLength;
     }
 
-    public static Builder builder(List<String> vocab) {
-        return new Builder(vocab);
+    public static Builder builder(List<String> vocab, Tokenization tokenization) {
+        return new Builder(vocab, tokenization);
     }
 
     public static class Builder {
 
-        private final List<String> originalVocab;
-        private final SortedMap<String, Integer> vocab;
-        private boolean doLowerCase = false;
-        private boolean doTokenizeCjKChars = true;
-        private boolean withSpecialTokens = true;
-        private Boolean doStripAccents = null;
-        private Set<String> neverSplit;
-
-        private Builder(List<String> vocab) {
+        protected final List<String> originalVocab;
+        protected final SortedMap<String, Integer> vocab;
+        protected boolean doLowerCase = false;
+        protected boolean doTokenizeCjKChars = true;
+        protected boolean withSpecialTokens = true;
+        protected int maxSequenceLength;
+        protected Boolean doStripAccents = null;
+        protected Set<String> neverSplit;
+        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;
+
+        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = buildSortedVocab(vocab);
+            this.doLowerCase = tokenization.doLowerCase();
+            this.withSpecialTokens = tokenization.withSpecialTokens();
+            this.maxSequenceLength = tokenization.maxSequenceLength();
         }
 
         private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
@@ -229,6 +209,11 @@ public class BertTokenizer {
             return this;
         }
 
+        public Builder setMaxSequenceLength(int maxSequenceLength) {
+            this.maxSequenceLength = maxSequenceLength;
+            return this;
+        }
+
         /**
          * Include CLS and SEP tokens
          * @param withSpecialTokens if true include CLS and SEP tokens
@@ -239,6 +224,11 @@ public class BertTokenizer {
             return this;
         }
 
+        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
+            this.requestBuilderFactory = requestBuilderFactory;
+            return this;
+        }
+
         public BertTokenizer build() {
             // if not set strip accents defaults to the value of doLowerCase
             if (doStripAccents == null) {
@@ -249,7 +239,17 @@ public class BertTokenizer {
                 neverSplit = Collections.emptySet();
             }
 
-            return new BertTokenizer(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, neverSplit);
+            return new BertTokenizer(
+                originalVocab,
+                vocab,
+                doLowerCase,
+                doTokenizeCjKChars,
+                doStripAccents,
+                withSpecialTokens,
+                maxSequenceLength,
+                requestBuilderFactory,
+                neverSplit
+            );
         }
     }
 }
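
A minimal usage sketch of the builder API after this change, for reviewers; this is not code from the PR. The caller-supplied `vocab` and `Tokenization` config, the helper name `tokenizeExample`, and the override value `128` are illustrative assumptions, and `TokenizationResult` is assumed to now live as a top-level class in the same package as `BertTokenizer`:

```java
import java.util.List;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

class BertTokenizerUsageSketch {

    // In production, `vocab` would come from the model's vocabulary file and
    // `tokenization` from the trained model config; here both are supplied by the caller.
    static TokenizationResult tokenizeExample(List<String> vocab, Tokenization tokenization, String text) {
        BertTokenizer tokenizer = BertTokenizer.builder(vocab, tokenization)
            // builder() seeds doLowerCase, withSpecialTokens and maxSequenceLength
            // from the Tokenization config; this setter is an optional override
            // (128 is an arbitrary value for this sketch).
            .setMaxSequenceLength(128)
            // build() uses the default BertRequestBuilder::new factory unless
            // setRequestBuilderFactory(...) was called.
            .build();

        // With this change, tokenize() throws a bad-request exception when the
        // tokenized length exceeds maxSequenceLength, rather than sending an
        // over-long sequence to the model.
        return tokenizer.tokenize(text);
    }
}
```

Passing a `Function<BertTokenizer, NlpTask.RequestBuilder>` rather than a ready-made request builder lets the builder capture the fully constructed tokenizer, which is why the constructor applies the factory to `this` as its final step.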