瀏覽代碼

[ML] add zero_shot_classification task for BERT nlp models (#77799)

Zero-Shot classification allows for text classification tasks without a pre-trained collection of target labels.

This is achieved through models trained on the Multi-Genre Natural Language Inference (MNLI) dataset. This dataset pairs text sequences with "entailment" clauses. An example could be:

"Throughout all of history, mankind has shown itself resourceful, yet astoundingly short-sighted" could have been paired with the entailment clauses: ["This example is history", "This example is sociology"...]. 

This training set combined with the attention and semantic knowledge in modern day NLP models (BERT, BART, etc.) affords a powerful tool for ad-hoc text classification.

See https://arxiv.org/abs/1909.00161 for a deeper explanation of the MNLI training and how zero-shot works. 

The zero-shot classification task is configured as follows:
```js
{
   // <snip> model configuration </snip>
  "inference_config" : {
    "zero_shot_classification": {
      "classification_labels": ["entailment", "neutral", "contradiction"], // <1>
      "labels": ["sad", "glad", "mad", "rad"], // <2>
      "multi_label": false, // <3>
      "hypothesis_template": "This example is {}.", // <4>
      "tokenization": { /*<snip> tokenization configuration </snip>*/}
    }
  }
}
```
* <1> For all zero_shot models, these 3 particular labels are returned when classifying the target sequence. "entailment" is the positive case, "neutral" the case where the sequence isn't positive or negative, and "contradiction" is the negative case
* <2> This optional parameter provides the default labels that zero-shot classification attempts to assign
* <3> When returning the probabilities, should the results assume there is only one true label or multiple true labels
* <4> The hypothesis template when tokenizing the labels. When combining with `sad` the sequence looks like `This example is sad.`

For inference in a pipeline one may provide label updates:
```js
{
  //<snip> pipeline definition </snip>
  "processors": [
    //<snip> other processors </snip>
    {
      "inference": {
        // <snip> general configuration </snip>
        "inference_config": {
          "zero_shot_classification": {
             "labels": ["humanities", "science", "mathematics", "technology"], // <1>
             "multi_label": true // <2>
          }
        }
      }
    }
    //<snip> other processors </snip>
  ]
}
```
* <1> The `labels` we care about, these replace the default ones if they exist. 
* <2> Should the results allow multiple true labels

Similarly one may provide label changes against the `_infer` endpoint
```js
{
   "docs":[{ "text_field": "This is a very happy person"}],
   "inference_config":{"zero_shot_classification":{"labels": ["glad", "sad", "bad", "rad"], "multi_label": false}}
}
```
Benjamin Trent 4 年之前
父節點
當前提交
408489310c
共有 33 個文件被更改,包括 1493 次插入255 次删除
  1. 80 18
      docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc
  2. 54 3
      docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc
  3. 108 68
      docs/reference/ml/ml-shared.asciidoc
  4. 28 21
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java
  5. 18 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  6. 21 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java
  7. 245 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java
  8. 201 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java
  9. 24 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java
  10. 61 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java
  11. 134 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java
  12. 4 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java
  13. 9 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  14. 6 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java
  15. 16 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  16. 26 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  17. 7 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  18. 13 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  19. 6 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  20. 8 7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  21. 4 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  22. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  23. 7 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
  24. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  25. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
  26. 194 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  27. 110 52
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  28. 5 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  29. 13 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  30. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  31. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  32. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  33. 78 34
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

+ 80 - 18
docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc

@@ -27,7 +27,7 @@ Retrieves configuration information for a trained model.
 [[ml-get-trained-models-prereq]]
 == {api-prereq-title}
 
-Requires the `monitor_ml` cluster privilege. This privilege is included in the 
+Requires the `monitor_ml` cluster privilege. This privilege is included in the
 `machine_learning_user` built-in role.
 
 
@@ -71,9 +71,9 @@ default value is empty, indicating no optional fields are included. Valid
 options are:
  - `definition`: Includes the model definition.
  - `feature_importance_baseline`: Includes the baseline for {feat-imp} values.
- - `hyperparameters`: Includes the information about hyperparameters used to 
-    train the model. This information consists of the value, the absolute and 
-    relative importance of the hyperparameter as well as an indicator of whether 
+ - `hyperparameters`: Includes the information about hyperparameters used to
+    train the model. This information consists of the value, the absolute and
+    relative importance of the hyperparameter as well as an indicator of whether
     it was specified by the user or tuned during hyperparameter optimization.
  - `total_feature_importance`: Includes the total {feat-imp} for the training
    data set.
@@ -222,8 +222,8 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-ner]
 [%collapsible%open]
 ======
 `classification_labels`::::
-(Optional, string) 
-An array of classification labels. NER supports only 
+(Optional, string)
+An array of classification labels. NER supports only
 Inside-Outside-Beginning labels (IOB) and only persons, organizations, locations,
 and miscellaneous. For example:
 `["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]`.
@@ -338,7 +338,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific
 [%collapsible%open]
 ======
 `classification_labels`::::
-(Optional, string) 
+(Optional, string)
 An array of classification labels.
 
 `num_top_classes`::::
@@ -414,6 +414,68 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
+========
+=======
+`vocabulary`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-vocabulary]
++
+.Properties of vocabulary
+[%collapsible%open]
+=======
+`index`::::
+(Required, string)
+The index where the vocabulary is stored.
+=======
+======
+`zero_shot_classification`::::
+(Object, optional)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification]
++
+.Properties of zero_shot_classification inference
+[%collapsible%open]
+======
+`classification_labels`::::
+(Required, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-classification-labels]
+
+`hypothesis_template`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-hypothesis-template]
+
+`labels`::::
+(Optional, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-labels]
+
+`multi_label`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-multi-label]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+=======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
 `with_special_tokens`::::
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
@@ -456,7 +518,7 @@ provided.
 =====
 `index`:::
 (Required, object)
-Indicates that the model definition is stored in an index. It is required to be empty as 
+Indicates that the model definition is stored in an index. It is required to be empty as
 the index for storing model definitions is configured automatically.
 =====
 // End location
@@ -480,7 +542,7 @@ it is a single value. For {classanalysis}, there is a value for each class.
 
 `hyperparameters`:::
 (array)
-List of the available hyperparameters optimized during the 
+List of the available hyperparameters optimized during the
 `fine_parameter_tuning` phase as well as specified by the user.
 +
 .Properties of hyperparameters
@@ -488,10 +550,10 @@ List of the available hyperparameters optimized during the
 ======
 `absolute_importance`::::
 (double)
-A positive number showing how much the parameter influences the variation of the 
-{ml-docs}/dfa-regression-lossfunction.html[loss function]. For 
-hyperparameters with values that are not specified by the user but tuned during 
-hyperparameter optimization. 
+A positive number showing how much the parameter influences the variation of the
+{ml-docs}/dfa-regression-lossfunction.html[loss function]. For
+hyperparameters with values that are not specified by the user but tuned during
+hyperparameter optimization.
 
 `max_trees`::::
 (integer)
@@ -503,14 +565,14 @@ Name of the hyperparameter.
 
 `relative_importance`::::
 (double)
-A number between 0 and 1 showing the proportion of influence on the variation of 
-the loss function among all tuned hyperparameters. For hyperparameters with 
-values that are not specified by the user but tuned during hyperparameter 
+A number between 0 and 1 showing the proportion of influence on the variation of
+the loss function among all tuned hyperparameters. For hyperparameters with
+values that are not specified by the user but tuned during hyperparameter
 optimization.
 
 `supplied`::::
 (Boolean)
-Indicates if the hyperparameter is specified by the user (`true`) or optimized 
+Indicates if the hyperparameter is specified by the user (`true`) or optimized
 (`false`).
 
 `value`::::
@@ -602,7 +664,7 @@ Identifier for the trained model.
 `model_type`::
 (Optional, string)
 The created model type. By default the model type is `tree_ensemble`.
-Appropriate types are: 
+Appropriate types are:
 +
 --
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.

+ 54 - 3
docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc

@@ -377,7 +377,7 @@ A human-readable description of the {infer} trained model.
 `inference_config`::
 (Required, object)
 The default configuration for inference. This can be: `regression`,
-`classification`, `fill_mask`, `ner`, `text_classification`, or `text_embedding`. 
+`classification`, `fill_mask`, `ner`, `text_classification`, `text_embedding` or `zero_shot_classification`.
 If `regression` or `classification`, it must match the `target_type` of the
 underlying `definition.trained_model`. If `fill_mask`, `ner`,
 `text_classification`, or `text_embedding`; the `model_type` must be `pytorch`.
@@ -457,7 +457,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-ner]
 [%collapsible%open]
 =====
 `classification_labels`::::
-(Optional, string) 
+(Optional, string)
 An array of classification labels. NER only supports Inside-Outside-Beginning labels (IOB)
 and only persons, organizations, locations, and miscellaneous.
 Example: ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
@@ -614,6 +614,57 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
+=======
+======
+=====
+`zero_shot_classification`:::
+(Object, optional)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification]
++
+.Properties of zero_shot_classification inference
+[%collapsible%open]
+=====
+`classification_labels`::::
+(Required, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-classification-labels]
+
+`hypothesis_template`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-hypothesis-template]
+
+`labels`::::
+(Optional, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-labels]
+
+`multi_label`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-multi-label]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
 `with_special_tokens`::::
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
@@ -660,7 +711,7 @@ An object map that contains metadata about the model.
 `model_type`::
 (Optional, string)
 The created model type. By default the model type is `tree_ensemble`.
-Appropriate types are: 
+Appropriate types are:
 +
 --
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.

+ 108 - 68
docs/reference/ml/ml-shared.asciidoc

@@ -323,7 +323,7 @@ end::custom-preprocessor[]
 tag::custom-rules[]
 An array of custom rule objects, which enable you to customize the way detectors
 operate. For example, a rule may dictate to the detector conditions under which
-results should be skipped. {kib} refers to custom rules as _job rules_. For more 
+results should be skipped. {kib} refers to custom rules as _job rules_. For more
 examples, see
 {ml-docs}/ml-configuring-detector-custom-rules.html[Customizing detectors with custom rules].
 end::custom-rules[]
@@ -526,21 +526,21 @@ end::detector-index[]
 tag::dfas-alpha[]
 Advanced configuration option. {ml-cap} uses loss guided tree growing, which
 means that the decision trees grow where the regularized loss decreases most
-quickly. This parameter affects loss calculations by acting as a multiplier of 
-the tree depth. Higher alpha values result in shallower trees and faster 
-training times. By default, this value is calculated during hyperparameter 
-optimization. It must be greater than or equal to zero. 
+quickly. This parameter affects loss calculations by acting as a multiplier of
+the tree depth. Higher alpha values result in shallower trees and faster
+training times. By default, this value is calculated during hyperparameter
+optimization. It must be greater than or equal to zero.
 end::dfas-alpha[]
 
 tag::dfas-downsample-factor[]
-Advanced configuration option. Controls the fraction of data that is used to 
-compute the derivatives of the loss function for tree training. A small value 
-results in the use of a small fraction of the data. If this value is set to be 
-less than 1, accuracy typically improves. However, too small a value may result 
+Advanced configuration option. Controls the fraction of data that is used to
+compute the derivatives of the loss function for tree training. A small value
+results in the use of a small fraction of the data. If this value is set to be
+less than 1, accuracy typically improves. However, too small a value may result
 in poor convergence for the ensemble and so require more trees. For more
 information about shrinkage, refer to
 {wikipedia}/Gradient_boosting#Stochastic_gradient_boosting[this wiki article].
-By default, this value is calculated during hyperparameter optimization. It 
+By default, this value is calculated during hyperparameter optimization. It
 must be greater than zero and less than or equal to 1.
 end::dfas-downsample-factor[]
 
@@ -553,9 +553,9 @@ By default, early stoppping is enabled.
 end::dfas-early-stopping-enabled[]
 
 tag::dfas-eta-growth[]
-Advanced configuration option. Specifies the rate at which `eta` increases for 
-each new tree that is added to the forest. For example, a rate of 1.05 
-increases `eta` by 5% for each extra tree. By default, this value is calculated 
+Advanced configuration option. Specifies the rate at which `eta` increases for
+each new tree that is added to the forest. For example, a rate of 1.05
+increases `eta` by 5% for each extra tree. By default, this value is calculated
 during hyperparameter optimization. It must be between 0.5 and 2.
 end::dfas-eta-growth[]
 
@@ -565,16 +565,16 @@ candidate split.
 end::dfas-feature-bag-fraction[]
 
 tag::dfas-feature-processors[]
-Advanced configuration option. A collection of feature preprocessors that modify 
-one or more included fields. The analysis uses the resulting one or more 
-features instead of the original document field. However, these features are 
-ephemeral; they are not stored in the destination index. Multiple 
-`feature_processors` entries can refer to the same document fields. Automatic 
-categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs 
+Advanced configuration option. A collection of feature preprocessors that modify
+one or more included fields. The analysis uses the resulting one or more
+features instead of the original document field. However, these features are
+ephemeral; they are not stored in the destination index. Multiple
+`feature_processors` entries can refer to the same document fields. Automatic
+categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs
 for the fields that are unprocessed by a custom processor or that have
-categorical values. Use this property only if you want to override the automatic 
-feature encoding of the specified fields. Refer to 
-{ml-docs}/ml-feature-processors.html[{dfanalytics} feature processors] to learn 
+categorical values. Use this property only if you want to override the automatic
+feature encoding of the specified fields. Refer to
+{ml-docs}/ml-feature-processors.html[{dfanalytics} feature processors] to learn
 more.
 end::dfas-feature-processors[]
 
@@ -591,13 +591,13 @@ The configuration information necessary to perform frequency encoding.
 end::dfas-feature-processors-frequency[]
 
 tag::dfas-feature-processors-frequency-map[]
-The resulting frequency map for the field value. If the field value is missing 
+The resulting frequency map for the field value. If the field value is missing
 from the `frequency_map`, the resulting value is `0`.
 end::dfas-feature-processors-frequency-map[]
 
 tag::dfas-feature-processors-multi[]
-The configuration information necessary to perform multi encoding. It allows 
-multiple processors to be changed together. This way the output of a processor 
+The configuration information necessary to perform multi encoding. It allows
+multiple processors to be changed together. This way the output of a processor
 can then be passed to another as an input.
 end::dfas-feature-processors-multi[]
 
@@ -606,10 +606,10 @@ The ordered array of custom processors to execute. Must be more than 1.
 end::dfas-feature-processors-multi-proc[]
 
 tag::dfas-feature-processors-ngram[]
-The configuration information necessary to perform n-gram encoding. Features 
-created by this encoder have the following name format: 
-`<feature_prefix>.<ngram><string position>`. For example, if the 
-`feature_prefix` is `f`, the feature name for the second unigram in a string is 
+The configuration information necessary to perform n-gram encoding. Features
+created by this encoder have the following name format:
+`<feature_prefix>.<ngram><string position>`. For example, if the
+`feature_prefix` is `f`, the feature name for the second unigram in a string is
 `f.11`.
 end::dfas-feature-processors-ngram[]
 
@@ -622,17 +622,17 @@ The name of the text field to encode.
 end::dfas-feature-processors-ngram-field[]
 
 tag::dfas-feature-processors-ngram-length[]
-Specifies the length of the n-gram substring. Defaults to `50`. Must be greater 
+Specifies the length of the n-gram substring. Defaults to `50`. Must be greater
 than `0`.
 end::dfas-feature-processors-ngram-length[]
 
 tag::dfas-feature-processors-ngram-ngrams[]
-Specifies which n-grams to gather. It’s an array of integer values where the 
+Specifies which n-grams to gather. It’s an array of integer values where the
 minimum value is 1, and a maximum value is 5.
 end::dfas-feature-processors-ngram-ngrams[]
 
 tag::dfas-feature-processors-ngram-start[]
-Specifies the zero-indexed start of the n-gram substring. Negative values are 
+Specifies the zero-indexed start of the n-gram substring. Negative values are
 allowed for encoding n-grams of string suffixes. Defaults to `0`.
 end::dfas-feature-processors-ngram-start[]
 
@@ -686,19 +686,19 @@ decision tree when the tree is trained.
 end::dfas-num-splits[]
 
 tag::dfas-soft-limit[]
-Advanced configuration option. {ml-cap} uses loss guided tree growing, which 
-means that the decision trees grow where the regularized loss decreases most 
-quickly. This soft limit combines with the `soft_tree_depth_tolerance` to 
-penalize trees that exceed the specified depth; the regularized loss increases 
-quickly beyond this depth. By default, this value is calculated during 
+Advanced configuration option. {ml-cap} uses loss guided tree growing, which
+means that the decision trees grow where the regularized loss decreases most
+quickly. This soft limit combines with the `soft_tree_depth_tolerance` to
+penalize trees that exceed the specified depth; the regularized loss increases
+quickly beyond this depth. By default, this value is calculated during
 hyperparameter optimization. It must be greater than or equal to 0.
 end::dfas-soft-limit[]
 
 tag::dfas-soft-tolerance[]
-Advanced configuration option. This option controls how quickly the regularized 
-loss increases when the tree depth exceeds `soft_tree_depth_limit`. By default, 
-this value is calculated during hyperparameter optimization. It must be greater 
-than or equal to 0.01. 
+Advanced configuration option. This option controls how quickly the regularized
+loss increases when the tree depth exceeds `soft_tree_depth_limit`. By default,
+this value is calculated during hyperparameter optimization. It must be greater
+than or equal to 0.01.
 end::dfas-soft-tolerance[]
 
 tag::dfas-timestamp[]
@@ -744,7 +744,7 @@ end::empty-bucket-count[]
 tag::eta[]
 Advanced configuration option. The shrinkage applied to the weights. Smaller
 values result in larger forests which have a better generalization error.
-However, larger forests cause slower training. For more information about 
+However, larger forests cause slower training. For more information about
 shrinkage, refer to
 {wikipedia}/Gradient_boosting#Shrinkage[this wiki article].
 By default, this value is calculated during hyperparameter optimization. It must
@@ -833,10 +833,10 @@ end::function[]
 
 tag::gamma[]
 Advanced configuration option. Regularization parameter to prevent overfitting
-on the training data set. Multiplies a linear penalty associated with the size 
-of individual trees in the forest. A high gamma value causes training to prefer 
-small trees. A small gamma value results in larger individual trees and slower 
-training. By default, this value is calculated during hyperparameter 
+on the training data set. Multiplies a linear penalty associated with the size
+of individual trees in the forest. A high gamma value causes training to prefer
+small trees. A small gamma value results in larger individual trees and slower
+training. By default, this value is calculated during hyperparameter
 optimization. It must be a nonnegative value.
 end::gamma[]
 
@@ -849,7 +849,7 @@ An array of index names. Wildcards are supported. For example:
 `["it_ops_metrics", "server*"]`.
 +
 --
-NOTE: If any indices are in remote clusters then the {ml} nodes need to have the 
+NOTE: If any indices are in remote clusters then the {ml} nodes need to have the
 `remote_cluster_client` role.
 
 --
@@ -921,7 +921,7 @@ BERT-style tokenization is to be performed with the enclosed settings.
 end::inference-config-nlp-tokenization-bert[]
 
 tag::inference-config-nlp-tokenization-bert-do-lower-case[]
-Should the tokenization lower case the text sequence when building 
+Should the tokenization lower case the text sequence when building
 the tokens.
 end::inference-config-nlp-tokenization-bert-do-lower-case[]
 
@@ -930,7 +930,7 @@ Tokenize with special tokens. The tokens typically included in BERT-style tokeni
 +
 --
 * `[CLS]`: The first token of the sequence being classified.
-* `[SEP]`: Indicates sequence separation. 
+* `[SEP]`: Indicates sequence separation.
 --
 end::inference-config-nlp-tokenization-bert-with-special-tokens[]
 
@@ -998,6 +998,46 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 end::inference-config-results-field-processor[]
 
+tag::inference-config-zero-shot-classification[]
+Configures a zero-shot classification task. Zero-shot classification allows for
+text classification to occur without pre-determined labels. At inference time,
+it is possible to adjust the labels to classify. This makes this type of model
+and task exceptionally flexible.
+
+If consistently classifying the same labels, it may be better to use a fine-tuned
+text classification model.
+end::inference-config-zero-shot-classification[]
+
+tag::inference-config-zero-shot-classification-classification-labels[]
+The classification labels used during the zero-shot classification. Classification
+labels must not be empty or null and only set at model creation. They must be all three
+of ["entailment", "neutral", "contradiction"].
+
+NOTE: This is NOT the same as `labels` which are the values that zero-shot is attempting to
+      classify.
+end::inference-config-zero-shot-classification-classification-labels[]
+
+tag::inference-config-zero-shot-classification-hypothesis-template[]
+This is the template used when tokenizing the sequences for classification.
+
+The labels replace the `{}` value in the text. The default value is:
+`This example is {}.`
+end::inference-config-zero-shot-classification-hypothesis-template[]
+
+tag::inference-config-zero-shot-classification-labels[]
+The labels to classify. Can be set at creation for default labels, and
+then updated during inference.
+end::inference-config-zero-shot-classification-labels[]
+
+tag::inference-config-zero-shot-classification-multi-label[]
+Indicates if more than one `true` label is possible given the input.
+
+This is useful when labeling text that could pertain to more than one of the
+input labels.
+
+Defaults to `false`.
+end::inference-config-zero-shot-classification-multi-label[]
+
 tag::inference-metadata-feature-importance-feature-name[]
 The feature for which this importance was calculated.
 end::inference-metadata-feature-importance-feature-name[]
@@ -1102,11 +1142,11 @@ end::job-id-datafeed[]
 tag::lambda[]
 Advanced configuration option. Regularization parameter to prevent overfitting
 on the training data set. Multiplies an L2 regularization term which applies to
-leaf weights of the individual trees in the forest. A high lambda value causes 
-training to favor small leaf weights. This behavior makes the prediction 
+leaf weights of the individual trees in the forest. A high lambda value causes
+training to favor small leaf weights. This behavior makes the prediction
 function smoother at the expense of potentially not being able to capture
 relevant relationships between the features and the {depvar}. A small lambda
-value results in large individual trees and slower training. By default, this 
+value results in large individual trees and slower training. By default, this
 value is calculated during hyperparameter optimization. It must be a nonnegative
 value.
 end::lambda[]
@@ -1151,13 +1191,13 @@ set.
 end::max-empty-searches[]
 
 tag::max-trees[]
-Advanced configuration option. Defines the maximum number of decision trees in 
-the forest. The maximum value is 2000. By default, this value is calculated 
+Advanced configuration option. Defines the maximum number of decision trees in
+the forest. The maximum value is 2000. By default, this value is calculated
 during hyperparameter optimization.
 end::max-trees[]
 
 tag::max-trees-trained-models[]
-The maximum number of decision trees in the forest. The maximum value is 2000. 
+The maximum number of decision trees in the forest. The maximum value is 2000.
 By default, this value is calculated during hyperparameter optimization.
 end::max-trees-trained-models[]
 
@@ -1222,7 +1262,7 @@ default value for jobs created in version 6.1 and later is `1024mb`. If the
 than `1024mb`, however, that value is used instead. The default value is
 relatively small to ensure that high resource usage is a conscious decision. If
 you have jobs that are expected to analyze high cardinality fields, you will
-likely need to use a higher value. 
+likely need to use a higher value.
 +
 If you specify a number instead of a string, the units are assumed to be MiB.
 Specifying a string is recommended for clarity. If you specify a byte size unit
@@ -1299,11 +1339,11 @@ Only the specified `terms` can be viewed when using the Single Metric Viewer.
 end::model-plot-config-terms[]
 
 tag::model-prune-window[]
-Advanced configuration option. 
-Affects the pruning of models that have not been updated for the given time 
-duration. The value must be set to a multiple of the `bucket_span`. If set too 
-low, important information may be removed from the model. Typically, set to 
-`30d` or longer. If not set, model pruning only occurs if the model memory 
+Advanced configuration option.
+Affects the pruning of models that have not been updated for the given time
+duration. The value must be set to a multiple of the `bucket_span`. If set too
+low, important information may be removed from the model. Typically, set to
+`30d` or longer. If not set, model pruning only occurs if the model memory
 status reaches the soft limit or the hard limit.
 end::model-prune-window[]
 
@@ -1391,10 +1431,10 @@ end::open-time[]
 
 tag::out-of-order-timestamp-count[]
 The number of input documents that have a timestamp chronologically
-preceding the start of the current anomaly detection bucket offset by 
-the latency window. This information is applicable only when you provide 
-data to the {anomaly-job} by using the <<ml-post-data,post data API>>. 
-These out of order documents are discarded, since jobs require time 
+preceding the start of the current anomaly detection bucket offset by
+the latency window. This information is applicable only when you provide
+data to the {anomaly-job} by using the <<ml-post-data,post data API>>.
+These out of order documents are discarded, since jobs require time
 series data to be in ascending chronological order.
 end::out-of-order-timestamp-count[]
 
@@ -1459,9 +1499,9 @@ number of {es} documents.
 end::processed-record-count[]
 
 tag::randomize-seed[]
-Defines the seed for the random generator that is used to pick training data. By 
-default, it is randomly generated. Set it to a specific value to use the same 
-training data each time you start a job (assuming other related parameters such 
+Defines the seed for the random generator that is used to pick training data. By
+default, it is randomly generated. Set it to a specific value to use the same
+training data each time you start a job (assuming other related parameters such
 as `source` and `analyzed_fields` are the same).
 end::randomize-seed[]
 

+ 28 - 21
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -11,19 +11,19 @@ import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ParseField;
-import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -31,6 +31,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;
 
@@ -45,11 +46,12 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
         super(NAME, InferTrainedModelDeploymentAction.Response::new);
     }
 
-    public static class Request extends BaseTasksRequest<Request> implements ToXContentObject {
+    public static class Request extends BaseTasksRequest<Request> {
 
         public static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
         public static final ParseField DOCS = new ParseField("docs");
         public static final ParseField TIMEOUT = new ParseField("timeout");
+        public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
 
         public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(10);
 
@@ -58,6 +60,11 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             PARSER.declareString(Request.Builder::setDeploymentId, DEPLOYMENT_ID);
             PARSER.declareObjectArray(Request.Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
             PARSER.declareString(Request.Builder::setTimeout, TIMEOUT);
+            PARSER.declareNamedObject(
+                Request.Builder::setUpdate,
+                ((p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, c)),
+                INFERENCE_CONFIG
+            );
         }
 
         public static Request parseRequest(String deploymentId, XContentParser parser) {
@@ -70,16 +77,19 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
 
         private final String deploymentId;
         private final List<Map<String, Object>> docs;
+        private final InferenceConfigUpdate update;
 
-        public Request(String deploymentId, List<Map<String, Object>> docs) {
+        public Request(String deploymentId, InferenceConfigUpdate update, List<Map<String, Object>> docs) {
             this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);
             this.docs = ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS);
+            this.update = update;
         }
 
         public Request(StreamInput in) throws IOException {
             super(in);
             deploymentId = in.readString();
             docs = Collections.unmodifiableList(in.readList(StreamInput::readMap));
+            update = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
         }
 
         public String getDeploymentId() {
@@ -90,6 +100,10 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return docs;
         }
 
+        public InferenceConfigUpdate getUpdate() {
+            return Optional.ofNullable(update).orElse(new EmptyConfigUpdate());
+        }
+
         @Override
         public TimeValue getTimeout() {
             TimeValue tv = super.getTimeout();
@@ -124,16 +138,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             super.writeTo(out);
             out.writeString(deploymentId);
             out.writeCollection(docs, StreamOutput::writeMap);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
-            builder.startObject();
-            builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId);
-            builder.field(DOCS.getPreferredName(), docs);
-            builder.field(TIMEOUT.getPreferredName(), getTimeout().getStringRep());
-            builder.endObject();
-            return builder;
+            out.writeOptionalNamedWriteable(update);
         }
 
         @Override
@@ -148,17 +153,13 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o;
             return Objects.equals(deploymentId, that.deploymentId)
                 && Objects.equals(docs, that.docs)
+                && Objects.equals(update, that.update)
                 && Objects.equals(getTimeout(), that.getTimeout());
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(deploymentId, docs, getTimeout());
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
+            return Objects.hash(deploymentId, update, docs, getTimeout());
         }
 
         public static class Builder {
@@ -166,6 +167,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             private String deploymentId;
             private List<Map<String, Object>> docs;
             private TimeValue timeout;
+            private InferenceConfigUpdate update;
 
             private Builder() {}
 
@@ -184,12 +186,17 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 return this;
             }
 
+            public Builder setUpdate(InferenceConfigUpdate update) {
+                this.update = update;
+                return this;
+            }
+
             private Builder setTimeout(String timeout) {
                 return setTimeout(TimeValue.parseTimeValue(timeout, TIMEOUT.getPreferredName()));
             }
 
             public Request build() {
-                Request request = new Request(deploymentId, docs);
+                Request request = new Request(deploymentId, update, docs);
                 if (timeout != null) {
                     request.setTimeout(timeout);
                 }

+ 18 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -52,6 +52,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfi
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
@@ -184,11 +186,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new ParseField(TextEmbeddingConfig.NAME), TextEmbeddingConfig::fromXContentLenient));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(TextEmbeddingConfig.NAME),
             TextEmbeddingConfig::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
+            new ParseField(ZeroShotClassificationConfig.NAME), ZeroShotClassificationConfig::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class,
+            new ParseField(ZeroShotClassificationConfig.NAME),
+            ZeroShotClassificationConfig::fromXContentStrict));
 
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
             ClassificationConfigUpdate::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
             RegressionConfigUpdate::fromXContentStrict));
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                InferenceConfigUpdate.class,
+                new ParseField(ZeroShotClassificationConfigUpdate.NAME),
+                ZeroShotClassificationConfigUpdate::fromXContentStrict
+            )
+        );
 
         // Inference models
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
@@ -288,6 +302,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             PassThroughConfig.NAME, PassThroughConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             TextEmbeddingConfig.NAME, TextEmbeddingConfig::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
+            ZeroShotClassificationConfig.NAME, ZeroShotClassificationConfig::new));
 
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));
@@ -297,6 +313,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             ResultsFieldUpdate.NAME, ResultsFieldUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             EmptyConfigUpdate.NAME, EmptyConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            ZeroShotClassificationConfigUpdate.NAME, ZeroShotClassificationConfigUpdate::new));
 
         // Location
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModelLocation.class,

+ 21 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.xcontent.ParseField;
+
+public abstract class NlpConfigUpdate implements InferenceConfigUpdate {
+
+    static ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+
+    @Override
+    public InferenceConfig toConfig() {
+        throw new UnsupportedOperationException("cannot serialize to nodes before 7.8");
+    }
+
+}

+ 245 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java

@@ -0,0 +1,245 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+/**
+ * This builds out a 0-shot classification task.
+ *
+ * The 0-shot methodology assumed is MNLI optimized task. For further info see: https://arxiv.org/abs/1909.00161
+ *
+ */
+public class ZeroShotClassificationConfig implements NlpConfig {
+
+    public static final String NAME = "zero_shot_classification";
+    public static final ParseField HYPOTHESIS_TEMPLATE = new ParseField("hypothesis_template");
+    public static final ParseField MULTI_LABEL = new ParseField("multi_label");
+    public static final ParseField LABELS = new ParseField("labels");
+
+    public static ZeroShotClassificationConfig fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static ZeroShotClassificationConfig fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    private static final Set<String> REQUIRED_CLASSIFICATION_LABELS = new TreeSet<>(List.of("entailment", "neutral", "contradiction"));
+    private static final String DEFAULT_HYPOTHESIS_TEMPLATE = "This example is {}.";
+    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> LENIENT_PARSER = createParser(true);
+
+    @SuppressWarnings({ "unchecked"})
+    private static ConstructingObjectParser<ZeroShotClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<ZeroShotClassificationConfig, Void> parser = new ConstructingObjectParser<>(
+            NAME,
+            ignoreUnknownFields,
+            a -> new ZeroShotClassificationConfig(
+                (List<String>)a[0],
+                (VocabularyConfig) a[1],
+                (Tokenization) a[2],
+                (String) a[3],
+                (Boolean) a[4],
+                (List<String>) a[5]
+            )
+        );
+        parser.declareStringArray(ConstructingObjectParser.constructorArg(), CLASSIFICATION_LABELS);
+        parser.declareObject(
+            ConstructingObjectParser.optionalConstructorArg(),
+            (p, c) -> {
+                if (ignoreUnknownFields == false) {
+                    throw ExceptionsHelper.badRequestException(
+                        "illegal setting [{}] on inference model creation",
+                        VOCABULARY.getPreferredName()
+                    );
+                }
+                return VocabularyConfig.fromXContentLenient(p);
+            },
+            VOCABULARY
+        );
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+                TOKENIZATION
+        );
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), HYPOTHESIS_TEMPLATE);
+        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
+        parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
+        return parser;
+    }
+
+    private final VocabularyConfig vocabularyConfig;
+    private final Tokenization tokenization;
+    private final List<String> classificationLabels;
+    private final List<String> labels;
+    private final boolean isMultiLabel;
+    private final String hypothesisTemplate;
+
+    public ZeroShotClassificationConfig(
+        List<String> classificationLabels,
+        @Nullable VocabularyConfig vocabularyConfig,
+        @Nullable Tokenization tokenization,
+        @Nullable String hypothesisTemplate,
+        @Nullable Boolean isMultiLabel,
+        @Nullable List<String> labels
+    ) {
+        this.classificationLabels = ExceptionsHelper.requireNonNull(classificationLabels, CLASSIFICATION_LABELS);
+        if (this.classificationLabels.size() != 3) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] must contain exactly the three values {}",
+                CLASSIFICATION_LABELS.getPreferredName(),
+                REQUIRED_CLASSIFICATION_LABELS
+            );
+        }
+        List<String> badLabels = classificationLabels.stream()
+            .map(s -> s.toLowerCase(Locale.ROOT))
+            .filter(c -> REQUIRED_CLASSIFICATION_LABELS.contains(c) == false)
+            .collect(Collectors.toList());
+        if (badLabels.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] must contain exactly the three values {}. Invalid labels {}",
+                CLASSIFICATION_LABELS.getPreferredName(),
+                REQUIRED_CLASSIFICATION_LABELS,
+                badLabels
+            );
+        }
+        this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
+            .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        this.isMultiLabel = isMultiLabel != null && isMultiLabel;
+        this.hypothesisTemplate = Optional.ofNullable(hypothesisTemplate).orElse(DEFAULT_HYPOTHESIS_TEMPLATE);
+        this.labels = labels;
+        if (labels != null && labels.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
+        }
+    }
+
+    public ZeroShotClassificationConfig(StreamInput in) throws IOException {
+        vocabularyConfig = new VocabularyConfig(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
+        classificationLabels = in.readStringList();
+        isMultiLabel = in.readBoolean();
+        hypothesisTemplate = in.readString();
+        labels = in.readOptionalStringList();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        vocabularyConfig.writeTo(out);
+        out.writeNamedWriteable(tokenization);
+        out.writeStringCollection(classificationLabels);
+        out.writeBoolean(isMultiLabel);
+        out.writeString(hypothesisTemplate);
+        out.writeOptionalStringCollection(labels);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        builder.field(MULTI_LABEL.getPreferredName(), isMultiLabel);
+        builder.field(HYPOTHESIS_TEMPLATE.getPreferredName(), hypothesisTemplate);
+        if (labels != null) {
+            builder.field(LABELS.getPreferredName(), labels);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        return false;
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return Version.V_8_0_0;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        ZeroShotClassificationConfig that = (ZeroShotClassificationConfig) o;
+        return Objects.equals(vocabularyConfig, that.vocabularyConfig)
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(isMultiLabel, that.isMultiLabel)
+            && Objects.equals(hypothesisTemplate, that.hypothesisTemplate)
+            && Objects.equals(labels, that.labels)
+            && Objects.equals(classificationLabels, that.classificationLabels);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, hypothesisTemplate, isMultiLabel, labels);
+    }
+
+    @Override
+    public VocabularyConfig getVocabularyConfig() {
+        return vocabularyConfig;
+    }
+
+    @Override
+    public Tokenization getTokenization() {
+        return tokenization;
+    }
+
+    public List<String> getClassificationLabels() {
+        return classificationLabels;
+    }
+
+    public boolean isMultiLabel() {
+        return isMultiLabel;
+    }
+
+    public String getHypothesisTemplate() {
+        return hypothesisTemplate;
+    }
+
+    public List<String> getLabels() {
+        return Optional.ofNullable(labels).orElse(List.of());
+    }
+
+    @Override
+    public boolean isAllocateOnly() {
+        return true;
+    }
+
+}

+ 201 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java

@@ -0,0 +1,201 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.LABELS;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.MULTI_LABEL;
+
+public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+
+    public static final String NAME = "zero_shot_classification";
+
+    public static ZeroShotClassificationConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    @SuppressWarnings({ "unchecked"})
+    public static ZeroShotClassificationConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        Boolean isMultiLabel = (Boolean)options.remove(MULTI_LABEL.getPreferredName());
+        List<String> labels = (List<String>)options.remove(LABELS.getPreferredName());
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
+        }
+        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+    }
+
+    @SuppressWarnings({ "unchecked"})
+    private static final ConstructingObjectParser<ZeroShotClassificationConfigUpdate, Void> STRICT_PARSER = new ConstructingObjectParser<>(
+        NAME,
+        a -> new ZeroShotClassificationConfigUpdate((List<String>)a[0], (Boolean) a[1])
+    );
+
+    static {
+        STRICT_PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
+        STRICT_PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
+    }
+
+    private final List<String> labels;
+    private final Boolean isMultiLabel;
+
+    public ZeroShotClassificationConfigUpdate(
+        @Nullable List<String> labels,
+        @Nullable Boolean isMultiLabel
+    ) {
+        this.labels = labels;
+        if (labels != null && labels.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
+        }
+        this.isMultiLabel = isMultiLabel;
+    }
+
+    public ZeroShotClassificationConfigUpdate(StreamInput in) throws IOException {
+        labels = in.readOptionalStringList();
+        isMultiLabel = in.readOptionalBoolean();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalStringCollection(labels);
+        out.writeOptionalBoolean(isMultiLabel);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (labels != null) {
+            builder.field(LABELS.getPreferredName(), labels);
+        }
+        if (isMultiLabel != null) {
+            builder.field(MULTI_LABEL.getPreferredName(), isMultiLabel);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof ZeroShotClassificationConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig)originalConfig;
+        if ((labels == null || labels.isEmpty()) && (zeroShotConfig.getLabels() == null || zeroShotConfig.getLabels().isEmpty())) {
+            throw ExceptionsHelper.badRequestException(
+                "stored configuration has no [{}] defined, supplied inference_config update must supply [{}]",
+                LABELS.getPreferredName(),
+                LABELS.getPreferredName()
+            );
+        }
+        if (isNoop(zeroShotConfig)) {
+            return originalConfig;
+        }
+        return new ZeroShotClassificationConfig(
+            zeroShotConfig.getClassificationLabels(),
+            zeroShotConfig.getVocabularyConfig(),
+            zeroShotConfig.getTokenization(),
+            zeroShotConfig.getHypothesisTemplate(),
+            Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
+            Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels())
+        );
+    }
+
+    boolean isNoop(ZeroShotClassificationConfig originalConfig) {
+        return (labels == null || labels.equals(originalConfig.getClassificationLabels()))
+            && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()));
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof ZeroShotClassificationConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return null;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder().setLabels(labels).setMultiLabel(isMultiLabel);
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        ZeroShotClassificationConfigUpdate that = (ZeroShotClassificationConfigUpdate) o;
+        return Objects.equals(isMultiLabel, that.isMultiLabel) && Objects.equals(labels, that.labels);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(labels, isMultiLabel);
+    }
+
+    public List<String> getLabels() {
+        return labels;
+    }
+
+    public static class Builder implements InferenceConfigUpdate.Builder<
+        ZeroShotClassificationConfigUpdate.Builder,
+        ZeroShotClassificationConfigUpdate
+        > {
+        private List<String> labels;
+        private Boolean isMultiLabel;
+
+        @Override
+        public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
+            throw new IllegalArgumentException();
+        }
+
+        public Builder setLabels(List<String> labels) {
+            this.labels = labels;
+            return this;
+        }
+
+        public Builder setMultiLabel(Boolean multiLabel) {
+            isMultiLabel = multiLabel;
+            return this;
+        }
+
+        public ZeroShotClassificationConfigUpdate build() {
+            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+        }
+    }
+}

+ 24 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java

@@ -7,19 +7,24 @@
 
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Tuple;
-import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
 
-import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
-public class InferTrainedModelDeploymentRequestsTests extends AbstractSerializingTestCase<InferTrainedModelDeploymentAction.Request> {
-    @Override
-    protected InferTrainedModelDeploymentAction.Request doParseInstance(XContentParser parser) throws IOException {
-        return InferTrainedModelDeploymentAction.Request.parseRequest(null, parser);
+public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerializingTestCase<InferTrainedModelDeploymentAction.Request> {
+
+
+    private static InferenceConfigUpdate randomInferenceConfigUpdate() {
+        return randomFrom(ZeroShotClassificationConfigUpdateTests.createRandom(), EmptyConfigUpdateTests.testInstance());
     }
 
     @Override
@@ -32,14 +37,24 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractSerializin
         List<Map<String, Object>> docs = randomList(5, () -> randomMap(1, 3,
             () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7))));
 
-        InferTrainedModelDeploymentAction.Request request =
-            new InferTrainedModelDeploymentAction.Request(randomAlphaOfLength(4), docs);
+        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
+            randomAlphaOfLength(4),
+            randomBoolean() ? null : randomInferenceConfigUpdate(),
+            docs
+        );
         if (randomBoolean()) {
             request.setTimeout(randomTimeValue());
         }
         return request;
     }
 
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
     public void testTimeoutNotNull() {
         assertNotNull(createTestInstance().getTimeout());
     }

+ 61 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.Predicate;
+
+public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCase<ZeroShotClassificationConfig> {
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.isEmpty() == false;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig doParseInstance(XContentParser parser) throws IOException {
+        return ZeroShotClassificationConfig.fromXContentLenient(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ZeroShotClassificationConfig> instanceReader() {
+        return ZeroShotClassificationConfig::new;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig mutateInstanceForVersion(ZeroShotClassificationConfig instance, Version version) {
+        return instance;
+    }
+
+    public static ZeroShotClassificationConfig createRandom() {
+        return new ZeroShotClassificationConfig(
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10))
+        );
+    }
+}

+ 134 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -0,0 +1,134 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+
+public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItemTestCase<ZeroShotClassificationConfigUpdate> {
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return ZeroShotClassificationConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ZeroShotClassificationConfigUpdate> instanceReader() {
+        return ZeroShotClassificationConfigUpdate::new;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
+        return instance;
+    }
+
+    public void testFromMap() {
+        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(List.of("foo", "bar"), false);
+        Map<String, Object> config = new HashMap<>(){{
+            put(ZeroShotClassificationConfig.LABELS.getPreferredName(), List.of("foo", "bar"));
+            put(ZeroShotClassificationConfig.MULTI_LABEL.getPreferredName(), false);
+        }};
+        assertThat(ZeroShotClassificationConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> ZeroShotClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+    public void testApply() {
+        ZeroShotClassificationConfig originalConfig = new ZeroShotClassificationConfig(
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            randomList(1, 5, () -> randomAlphaOfLength(10))
+        );
+
+        assertThat(originalConfig, equalTo(new ZeroShotClassificationConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getHypothesisTemplate(),
+                originalConfig.isMultiLabel(),
+                List.of("foo", "bar")
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder()
+                    .setLabels(List.of("foo", "bar")).build()
+                    .apply(originalConfig)
+            )
+        );
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getHypothesisTemplate(),
+                true,
+                originalConfig.getLabels()
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder()
+                    .setMultiLabel(true).build()
+                    .apply(originalConfig)
+            )
+        );
+    }
+
+    public void testApplyWithEmptyLabelsInConfigAndUpdate() {
+        ZeroShotClassificationConfig originalConfig = new ZeroShotClassificationConfig(
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            null
+        );
+
+        Exception ex = expectThrows(Exception.class, () -> new ZeroShotClassificationConfigUpdate.Builder().build().apply(originalConfig));
+        assertThat(
+            ex.getMessage(),
+            containsString("stored configuration has no [labels] defined, supplied inference_config update must supply [labels]")
+        );
+    }
+
+    public static ZeroShotClassificationConfigUpdate createRandom() {
+        return new ZeroShotClassificationConfigUpdate(
+            randomBoolean() ? null : randomList(1,5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomBoolean()
+        );
+    }
+}

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -80,7 +80,10 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
     @Override
     protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task,
                                  ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
-        task.infer(request.getDocs().get(0), request.getTimeout(),
+        task.infer(
+            request.getDocs().get(0),
+            request.getUpdate(),
+            request.getTimeout(),
             ActionListener.wrap(
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
                 listener::onFailure)

+ 9 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -138,7 +139,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
                 // Always fail immediately and return an error
                 ex -> true);
         request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(
-            chainedTask -> inferSingleDocAgainstAllocatedModel(request.getModelId(), stringObjectMap, chainedTask)));
+            chainedTask -> inferSingleDocAgainstAllocatedModel(request.getModelId(), request.getUpdate(), stringObjectMap, chainedTask)));
 
         typedChainTaskExecutor.execute(ActionListener.wrap(
             inferenceResults -> listener.onResponse(responseBuilder.setInferenceResults(inferenceResults)
@@ -148,11 +149,16 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         ));
     }
 
-    private void inferSingleDocAgainstAllocatedModel(String modelId, Map<String, Object> doc, ActionListener<InferenceResults> listener) {
+    private void inferSingleDocAgainstAllocatedModel(
+        String modelId,
+        InferenceConfigUpdate inferenceConfigUpdate,
+        Map<String, Object> doc,
+        ActionListener<InferenceResults> listener
+    ) {
         executeAsyncWithOrigin(client,
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
-            new InferTrainedModelDeploymentAction.Request(modelId, Collections.singletonList(doc)),
+            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc)),
             ActionListener.wrap(
                 r -> listener.onResponse(r.getResults()),
                 e -> listener.onResponse(new WarningInferenceResults(e.getMessage()))

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
@@ -227,9 +228,12 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
         );
     }
 
-    public void infer(TrainedModelDeploymentTask task, Map<String, Object> doc, TimeValue timeout,
+    public void infer(TrainedModelDeploymentTask task,
+                      InferenceConfig config,
+                      Map<String, Object> doc,
+                      TimeValue timeout,
                       ActionListener<InferenceResults> listener) {
-        deploymentManager.infer(task, doc, timeout, listener);
+        deploymentManager.infer(task, config, doc, timeout, listener);
     }
 
     public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {

+ 16 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
@@ -131,6 +132,7 @@ public class DeploymentManager {
 
                 assert modelConfig.getInferenceConfig() instanceof NlpConfig;
                 NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
+                task.init(nlpConfig);
 
                 SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                 executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
@@ -203,7 +205,9 @@ public class DeploymentManager {
     }
 
     public void infer(TrainedModelDeploymentTask task,
-                      Map<String, Object> doc, TimeValue timeout,
+                      InferenceConfig config,
+                      Map<String, Object> doc,
+                      TimeValue timeout,
                       ActionListener<InferenceResults> listener) {
         if (task.isStopped()) {
             listener.onFailure(
@@ -240,12 +244,20 @@ public class DeploymentManager {
                     List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
                     NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                     processor.validateInputs(text);
-                    NlpTask.Request request = processor.getRequestBuilder().buildRequest(text, requestId);
+                    assert config instanceof NlpConfig;
+                    NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestId);
                     logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
                     PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
                     processContext.process.get().writeInferenceRequest(request.processInput);
-                    waitForResult(processContext, pendingResult, request.tokenization, requestId, timeout, processor.getResultProcessor(),
-                        listener);
+                    waitForResult(
+                        processContext,
+                        pendingResult,
+                        request.tokenization,
+                        requestId,
+                        timeout,
+                        processor.getResultProcessor((NlpConfig) config),
+                        listener
+                    );
                 } catch (IOException e) {
                     logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
                     onFailure(ExceptionsHelper.serverError("error writing to process", e));

+ 26 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -18,6 +18,8 @@ import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
 
@@ -32,6 +34,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
     private final TrainedModelAllocationNodeService trainedModelAllocationNodeService;
     private volatile boolean stopped;
     private final SetOnce<String> stoppedReason = new SetOnce<>();
+    private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
 
     public TrainedModelDeploymentTask(
         long id,
@@ -50,6 +53,10 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         );
     }
 
+    void init(InferenceConfig inferenceConfig) {
+        this.inferenceConfig.set(inferenceConfig);
+    }
+
     public String getModelId() {
         return params.getModelId();
     }
@@ -85,8 +92,25 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         stop(reason);
     }
 
-    public void infer(Map<String, Object> doc, TimeValue timeout, ActionListener<InferenceResults> listener) {
-        trainedModelAllocationNodeService.infer(this, doc, timeout, listener);
+    public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
+        if (inferenceConfig.get() == null) {
+            listener.onFailure(
+                ExceptionsHelper.badRequestException("[{}] inference not possible against uninitialized model", params.getModelId())
+            );
+            return;
+        }
+        if (update.isSupported(inferenceConfig.get()) == false) {
+            listener.onFailure(
+                ExceptionsHelper.badRequestException(
+                    "[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
+                    params.getModelId(),
+                    inferenceConfig.get().getName(),
+                    update.getName()
+                )
+            );
+            return;
+        }
+        trainedModelAllocationNodeService.infer(this, update.apply(inferenceConfig.get()), doc, timeout, listener);
     }
 
     public Optional<ModelStats> modelStats() {

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -37,6 +37,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
@@ -358,7 +360,11 @@ public class InferenceProcessor extends AbstractProcessor {
             } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                 return RegressionConfigUpdate.fromMap(valueMap);
-            } else {
+            } else if (configMap.containsKey(ZeroShotClassificationConfig.NAME)) {
+                checkSupportedVersion(new ZeroShotClassificationConfig(List.of("unused"), null, null, null, null, null));
+                return ZeroShotClassificationConfigUpdate.fromMap(valueMap);
+            }
+            else {
                 throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
                     configMap.keySet(),
                     Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));

+ 13 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.io.IOException;
 import java.util.List;
+import java.util.stream.Collectors;
 
 public class BertRequestBuilder implements NlpTask.RequestBuilder {
 
@@ -37,7 +38,18 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
                 " token in its vocabulary");
         }
 
-        TokenizationResult tokenization = tokenizer.tokenize(inputs);
+        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+            inputs.stream().map(tokenizer::tokenize).collect(Collectors.toList())
+        );
+        return buildRequest(tokenization, requestId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
+        if (tokenizer.getPadToken().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
+                " token in its vocabulary");
+        }
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
     }
 

+ 6 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -10,12 +10,14 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -49,23 +51,23 @@ public class FillMaskProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
     }
 
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return this::processResult;
     }
 
     InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
 
         if (tokenization.getTokenizations().isEmpty() ||
-            tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
+            tokenization.getTokenizations().get(0).getTokens().length == 0) {
             return new FillMaskResults(Collections.emptyList());
         }
 
-        int maskTokenIndex = tokenization.getTokenizations().get(0).getTokens().indexOf(BertTokenizer.MASK_TOKEN);
+        int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN);
         // TODO - process all results in the batch
         double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
 

+ 8 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
@@ -124,12 +125,12 @@ public class NerProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
     }
 
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return new NerResultProcessor(iobMap);
     }
 
@@ -143,7 +144,7 @@ public class NerProcessor implements NlpTask.Processor {
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
             if (tokenization.getTokenizations().isEmpty() ||
-                tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
+                tokenization.getTokenizations().get(0).getTokens().length == 0) {
                 return new NerResults(Collections.emptyList());
             }
             // TODO - process all results in the batch
@@ -171,7 +172,7 @@ public class NerProcessor implements NlpTask.Processor {
                                            IobTag[] iobMap) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
-            while (startTokenIndex < tokenization.getTokens().size()) {
+            while (startTokenIndex < tokenization.getTokens().length) {
                 int inputMapping = tokenization.getTokenMap()[startTokenIndex];
                 if (inputMapping < 0) {
                     // This token does not map to a token in the input (special tokens)
@@ -179,14 +180,14 @@ public class NerProcessor implements NlpTask.Processor {
                     continue;
                 }
                 int endTokenIndex = startTokenIndex;
-                StringBuilder word = new StringBuilder(tokenization.getTokens().get(startTokenIndex));
-                while (endTokenIndex < tokenization.getTokens().size() - 1
+                StringBuilder word = new StringBuilder(tokenization.getTokens()[startTokenIndex]);
+                while (endTokenIndex < tokenization.getTokens().length - 1
                     && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) {
                     endTokenIndex++;
                     // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens.
                     // It is probably more correct to implement detokenization on the tokenizer
                     // that does reverse lookup based on token IDs.
-                    String endTokenWord = tokenization.getTokens().get(endTokenIndex).substring(2);
+                    String endTokenWord = tokenization.getTokens()[endTokenIndex].substring(2);
                     word.append(endTokenWord);
                 }
                 double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);

+ 4 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -56,6 +56,8 @@ public class NlpTask {
 
         Request buildRequest(List<String> inputs, String requestId) throws IOException;
 
+        Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException;
+
         static void writePaddedTokens(String fieldName,
                                       TokenizationResult tokenization,
                                       int padToken,
@@ -97,10 +99,6 @@ public class NlpTask {
         InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult);
     }
 
-    public interface ResultProcessorFactory {
-        ResultProcessor build(TokenizationResult tokenizationResult);
-    }
-
     public interface Processor {
         /**
          * Validate the task input string.
@@ -110,8 +108,8 @@ public class NlpTask {
          */
         void validateInputs(List<String> inputs);
 
-        RequestBuilder getRequestBuilder();
-        ResultProcessor getResultProcessor();
+        RequestBuilder getRequestBuilder(NlpConfig config);
+        ResultProcessor getResultProcessor(NlpConfig config);
     }
 
     public static String extractInput(TrainedModelInput input, Map<String, Object> doc) {

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -34,12 +35,12 @@ public class PassThroughProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
     }
 
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return PassThroughProcessor::processResult;
     }
 

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 import java.util.Locale;
@@ -48,6 +49,12 @@ public enum TaskType {
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new TextEmbeddingProcessor(tokenizer, (TextEmbeddingConfig) config);
         }
+    },
+    ZERO_SHOT_CLASSIFICATION {
+        @Override
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
+            return new ZeroShotClassificationProcessor(tokenizer, (ZeroShotClassificationConfig) config);
+        }
     };
 
     public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -73,12 +74,12 @@ public class TextClassificationProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
     }
 
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return this::processResult;
     }
 

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -33,12 +34,12 @@ public class TextEmbeddingProcessor implements NlpTask.Processor {
     }
 
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
     }
 
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return TextEmbeddingProcessor::processResult;
     }
 

+ 194 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -0,0 +1,194 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class ZeroShotClassificationProcessor implements NlpTask.Processor {
+
+    private final NlpTokenizer tokenizer;
+    private final int entailmentPos;
+    private final int contraPos;
+    private final String[] labels;
+    private final String hypothesisTemplate;
+    private final boolean isMultiLabel;
+
+    ZeroShotClassificationProcessor(NlpTokenizer tokenizer, ZeroShotClassificationConfig config) {
+        this.tokenizer = tokenizer;
+        List<String> lowerCased = config.getClassificationLabels()
+            .stream()
+            .map(s -> s.toLowerCase(Locale.ROOT))
+            .collect(Collectors.toList());
+        this.entailmentPos = lowerCased.indexOf("entailment");
+        this.contraPos = lowerCased.indexOf("contradiction");
+        if (entailmentPos == -1 || contraPos == -1) {
+            throw ExceptionsHelper.badRequestException(
+                "zero_shot_classification requires [entailment] and [contradiction] in classification_labels"
+            );
+        }
+        this.labels = Optional.ofNullable(config.getLabels()).orElse(List.of()).toArray(String[]::new);
+        this.hypothesisTemplate = config.getHypothesisTemplate();
+        this.isMultiLabel = config.isMultiLabel();
+    }
+
+    @Override
+    public void validateInputs(List<String> inputs) {
+        // nothing to validate
+    }
+
+    @Override
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
+        final String[] labels;
+        if (nlpConfig instanceof ZeroShotClassificationConfig) {
+            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
+            labels = zeroShotConfig.getLabels().toArray(new String[0]);
+        } else {
+            labels = this.labels;
+        }
+        if (labels == null || labels.length == 0) {
+            throw ExceptionsHelper.badRequestException("zero_shot_classification requires non-empty [labels]");
+        }
+        return new RequestBuilder(tokenizer, labels, hypothesisTemplate);
+    }
+
+    @Override
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
+        final String[] labels;
+        final boolean isMultiLabel;
+        if (nlpConfig instanceof ZeroShotClassificationConfig) {
+            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
+            labels = zeroShotConfig.getLabels().toArray(new String[0]);
+            isMultiLabel = zeroShotConfig.isMultiLabel();
+        } else {
+            labels = this.labels;
+            isMultiLabel = this.isMultiLabel;
+        }
+        return new ResultProcessor(entailmentPos, contraPos, labels, isMultiLabel);
+    }
+
+    static class RequestBuilder implements NlpTask.RequestBuilder {
+
+        private final NlpTokenizer tokenizer;
+        private final String[] labels;
+        private final String hypothesisTemplate;
+
+        RequestBuilder(NlpTokenizer tokenizer, String[] labels, String hypothesisTemplate) {
+            this.tokenizer = tokenizer;
+            this.labels = labels;
+            this.hypothesisTemplate = hypothesisTemplate;
+        }
+
+        @Override
+        public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+            if (inputs.size() > 1) {
+                throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
+            }
+            List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
+            for (String label : labels) {
+                tokenizations.add(tokenizer.tokenize(inputs.get(0), LoggerMessageFormat.format(null, hypothesisTemplate, label)));
+            }
+            TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations);
+            return buildRequest(result, requestId);
+        }
+
+        @Override
+        public NlpTask.Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException {
+            return tokenizer.requestBuilder().buildRequest(tokenizationResult, requestId);
+        }
+    }
+
+    static class ResultProcessor implements NlpTask.ResultProcessor {
+        private final int entailmentPos;
+        private final int contraPos;
+        private final String[] labels;
+        private final boolean isMultiLabel;
+
+        ResultProcessor(int entailmentPos, int contraPos, String[] labels, boolean isMultiLabel) {
+            this.entailmentPos = entailmentPos;
+            this.contraPos = contraPos;
+            this.labels = labels;
+            this.isMultiLabel = isMultiLabel;
+        }
+
+        @Override
+        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+            if (pyTorchResult.getInferenceResult().length < 1) {
+                return new WarningInferenceResults("Zero shot classification result has no data");
+            }
+            // TODO only the first entry in the batch result is verified and
+            // checked. Implement for all in batch
+            if (pyTorchResult.getInferenceResult()[0].length != labels.length) {
+                return new WarningInferenceResults(
+                    "Expected exactly [{}] values in zero shot classification result; got [{}]",
+                    labels.length,
+                    pyTorchResult.getInferenceResult()[0].length
+                );
+            }
+            final double[] normalizedScores;
+            if (isMultiLabel) {
+                normalizedScores = new double[pyTorchResult.getInferenceResult()[0].length];
+                int v = 0;
+                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
+                    if (vals.length != 3) {
+                        return new WarningInferenceResults(
+                            "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                            3,
+                            vals.length
+                        );
+                    }
+                    // assume entailment is `0`, softmax between entailment and contradiction
+                    normalizedScores[v++] = NlpHelpers.convertToProbabilitiesBySoftMax(
+                        new double[]{vals[entailmentPos], vals[contraPos]}
+                    )[0];
+                }
+            } else {
+                double[] entailmentScores = new double[pyTorchResult.getInferenceResult()[0].length];
+                int v = 0;
+                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
+                    if (vals.length != 3) {
+                        return new WarningInferenceResults(
+                            "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                            3,
+                            vals.length
+                        );
+                    }
+                    entailmentScores[v++] = vals[entailmentPos];
+                }
+                normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(entailmentScores);
+            }
+
+            return new TextClassificationResults(
+                IntStream.range(0, normalizedScores.length)
+                    .mapToObj(i -> new TopClassEntry(labels[i], normalizedScores[i]))
+                    // Put the highest scoring class first
+                    .sorted(Comparator.comparing(TopClassEntry::getProbability).reversed())
+                    .limit(labels.length)
+                    .collect(Collectors.toList())
+            );
+        }
+    }
+}

+ 110 - 52
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
@@ -76,74 +77,63 @@ public class BertTokenizer implements NlpTokenizer {
         this.requestBuilder = requestBuilderFactory.apply(this);
     }
 
+    @Override
+    public OptionalInt getPadToken() {
+        Integer pad = vocab.get(PAD_TOKEN);
+        if (pad != null) {
+            return OptionalInt.of(pad);
+        } else {
+            return OptionalInt.empty();
+        }
+    }
+
+    @Override
+    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations) {
+        TokenizationResult tokenizationResult = new TokenizationResult(originalVocab);
+        for (TokenizationResult.Tokenization tokenization : tokenizations) {
+            tokenizationResult.addTokenization(tokenization);
+        }
+        return tokenizationResult;
+    }
+
     /**
-     * Tokenize the list of inputs according to the basic tokenization
+     * Tokenize the input according to the basic tokenization
      * options then perform Word Piece tokenization with the given vocabulary.
      *
      * The result is the Word Piece tokens, a map of the Word Piece
      * token position to the position of the token in the source for
      * each input string grouped into a {@link Tokenization}.
      *
-     * @param text Text to tokenize
+     * @param seq Text to tokenize
      * @return A {@link Tokenization}
      */
     @Override
-    public TokenizationResult tokenize(List<String> text) {
-        TokenizationResult tokenization = new TokenizationResult(originalVocab);
-
-        for (String input: text) {
-            addTokenization(tokenization, input);
-        }
-        return tokenization;
-    }
-
-
-    private void addTokenization(TokenizationResult tokenization, String text) {
-        BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
-
-        List<String> delineatedTokens = basicTokenizer.tokenize(text);
-        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = new ArrayList<>();
-        List<Integer> tokenPositionMap = new ArrayList<>();
-        if (withSpecialTokens) {
-            // insert the first token to simplify the loop counter logic later
-            tokenPositionMap.add(SPECIAL_TOKEN_POSITION);
-        }
-
-        for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) {
-            String token = delineatedTokens.get(sourceIndex);
-            if (neverSplit.contains(token)) {
-                wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))));
-                tokenPositionMap.add(sourceIndex);
-            } else {
-                List<WordPieceTokenizer.TokenAndId> tokens = wordPieceTokenizer.tokenize(token);
-                for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) {
-                    tokenPositionMap.add(sourceIndex);
-                }
-                wordPieceTokens.addAll(tokens);
-            }
-        }
-
+    public TokenizationResult.Tokenization tokenize(String seq) {
+        var innerResult = innerTokenize(seq);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
+        List<Integer> tokenPositionMap = innerResult.v2();
         int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size();
-        List<String> tokens = new ArrayList<>(numTokens);
-        int [] tokenIds = new int[numTokens];
-        int [] tokenMap = new int[numTokens];
+        String[] tokens = new String[numTokens];
+        int[] tokenIds = new int[numTokens];
+        int[] tokenMap = new int[numTokens];
 
         if (withSpecialTokens) {
-            tokens.add(CLASS_TOKEN);
+            tokens[0] = CLASS_TOKEN;
             tokenIds[0] = vocab.get(CLASS_TOKEN);
             tokenMap[0] = SPECIAL_TOKEN_POSITION;
         }
 
         int i = withSpecialTokens ? 1 : 0;
+        final int decrementHandler = withSpecialTokens ? 1 : 0;
         for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) {
-            tokens.add(tokenAndId.getToken());
+            tokens[i] = tokenAndId.getToken();
             tokenIds[i] = tokenAndId.getId();
-            tokenMap[i] = tokenPositionMap.get(i);
+            tokenMap[i] = tokenPositionMap.get(i-decrementHandler);
             i++;
         }
 
         if (withSpecialTokens) {
-            tokens.add(SEPARATOR_TOKEN);
+            tokens[i] = SEPARATOR_TOKEN;
             tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
@@ -155,18 +145,86 @@ public class BertTokenizer implements NlpTokenizer {
                 maxSequenceLength
             );
         }
-
-        tokenization.addTokenization(text, tokens, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(seq, tokens, tokenIds, tokenMap);
     }
 
     @Override
-    public OptionalInt getPadToken() {
-        Integer pad = vocab.get(PAD_TOKEN);
-        if (pad != null) {
-            return OptionalInt.of(pad);
-        } else {
-            return OptionalInt.empty();
+    public TokenizationResult.Tokenization tokenize(String seq1, String seq2) {
+        var innerResult = innerTokenize(seq1);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq1s = innerResult.v1();
+        List<Integer> tokenPositionMapSeq1 = innerResult.v2();
+        innerResult = innerTokenize(seq2);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq2s = innerResult.v1();
+        List<Integer> tokenPositionMapSeq2 = innerResult.v2();
+        if (withSpecialTokens == false)  {
+            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
+        }
+        // [CLS] seq1 [SEP] seq2 [SEP]
+        int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3;
+        String[] tokens = new String[numTokens];
+        int[] tokenIds = new int[numTokens];
+        int[] tokenMap = new int[numTokens];
+
+        tokens[0] = CLASS_TOKEN;
+        tokenIds[0] = vocab.get(CLASS_TOKEN);
+        tokenMap[0] = SPECIAL_TOKEN_POSITION;
+
+        int i = 1;
+        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq1s) {
+            tokens[i] = tokenAndId.getToken();
+            tokenIds[i] = tokenAndId.getId();
+            tokenMap[i] = tokenPositionMapSeq1.get(i - 1);
+            i++;
+        }
+        tokens[i] = SEPARATOR_TOKEN;
+        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
+        tokenMap[i] = SPECIAL_TOKEN_POSITION;
+        ++i;
+
+        int j = 0;
+        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq2s) {
+            tokens[i] = tokenAndId.getToken();
+            tokenIds[i] = tokenAndId.getId();
+            tokenMap[i] = tokenPositionMapSeq2.get(j);
+            i++;
+            j++;
+        }
+
+        tokens[i] = SEPARATOR_TOKEN;
+        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
+        tokenMap[i] = SPECIAL_TOKEN_POSITION;
+
+        // TODO handle seq1 truncation
+        if (tokenIds.length > maxSequenceLength) {
+            throw ExceptionsHelper.badRequestException(
+                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
+                tokenIds.length,
+                maxSequenceLength
+            );
+        }
+        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, tokenIds, tokenMap);
+    }
+
+    private Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerTokenize(String seq) {
+        BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
+        List<String> delineatedTokens = basicTokenizer.tokenize(seq);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = new ArrayList<>();
+        List<Integer> tokenPositionMap = new ArrayList<>();
+
+        for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) {
+            String token = delineatedTokens.get(sourceIndex);
+            if (neverSplit.contains(token)) {
+                wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))));
+                tokenPositionMap.add(sourceIndex);
+            } else {
+                List<WordPieceTokenizer.TokenAndId> tokens = wordPieceTokenizer.tokenize(token);
+                for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) {
+                    tokenPositionMap.add(sourceIndex);
+                }
+                wordPieceTokens.addAll(tokens);
+            }
         }
+        return Tuple.tuple(wordPieceTokens, tokenPositionMap);
     }
 
     @Override

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -22,7 +22,11 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.V
 
 public interface NlpTokenizer {
 
-    TokenizationResult tokenize(List<String> text);
+    TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations);
+
+    TokenizationResult.Tokenization tokenize(String seq);
+
+    TokenizationResult.Tokenization tokenize(String seq1, String seq2);
 
     NlpTask.RequestBuilder requestBuilder();
 

+ 13 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -29,26 +29,31 @@ public class TokenizationResult {
         return tokenizations;
     }
 
-    public void addTokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
+    public void addTokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
         maxLength = Math.max(maxLength, tokenIds.length);
         tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap));
     }
 
+    public void addTokenization(Tokenization tokenization) {
+        maxLength = Math.max(maxLength, tokenization.tokenIds.length);
+        tokenizations.add(tokenization);
+    }
+
     public int getLongestSequenceLength() {
         return maxLength;
     }
 
     public static class Tokenization {
 
-        String input;
-        private final List<String> tokens;
+        private final String inputSeqs;
+        private final String[] tokens;
         private final int[] tokenIds;
         private final int[] tokenMap;
 
-        public Tokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
-            assert tokens.size() == tokenIds.length;
+        public Tokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
+            assert tokens.length == tokenIds.length;
             assert tokenIds.length == tokenMap.length;
-            this.input = input;
+            this.inputSeqs = input;
             this.tokens = tokens;
             this.tokenIds = tokenIds;
             this.tokenMap = tokenMap;
@@ -59,7 +64,7 @@ public class TokenizationResult {
          *
          * @return A list of tokens
          */
-        public List<String> getTokens() {
+        public String[] getTokens() {
             return tokens;
         }
 
@@ -84,7 +89,7 @@ public class TokenizationResult {
         }
 
         public String getInput() {
-            return input;
+            return inputSeqs;
         }
     }
 }

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -41,7 +41,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris";
 
         List<String> vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France");
-        List<String> tokens = Arrays.asList(input.split(" "));
+        String[] tokens = input.split(" ");
         int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5};
 
@@ -68,7 +68,7 @@ public class FillMaskProcessorTests extends ESTestCase {
 
     public void testProcessResults_GivenMissingTokens() {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
-        tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
+        tokenization.addTokenization("", new String[]{}, new int[] {}, new int[] {});
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -222,6 +222,6 @@ public class NerProcessorTests extends ESTestCase {
             vocab,
             new BertTokenization(true, false, null)
         ).setDoLowerCase(true).setWithSpecialTokens(false).build();
-        return tokenizer.tokenize(List.of(input));
+        return tokenizer.buildTokenizationResult(List.of(tokenizer.tokenize(input)));
     }
 }

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -65,7 +65,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
-        NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
+        NlpTask.Request request = processor.getRequestBuilder(config).buildRequest(List.of("Elasticsearch fun"), "request1");
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 

+ 78 - 34
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -15,7 +15,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
-import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.hasSize;
 
 public class BertTokenizerTests extends ESTestCase {
@@ -26,9 +26,8 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null)
         ).build();
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun"));
         assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
     }
@@ -39,9 +38,8 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
         ).build();
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("elasticsearch fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("elasticsearch fun");
+        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
     }
@@ -56,9 +54,8 @@ public class BertTokenizerTests extends ESTestCase {
          .setWithSpecialTokens(false)
          .build();
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch " + specialToken + " fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", specialToken, "fun"));
         assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
     }
@@ -72,15 +69,13 @@ public class BertTokenizerTests extends ESTestCase {
              .setWithSpecialTokens(false)
              .build();
 
-            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
             assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
 
-            tr = tokenizer.tokenize(List.of("elasticsearch fun"));
-            tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
+            tokenization = tokenizer.tokenize("elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
 
         {
@@ -89,9 +84,8 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .build();
 
-            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
     }
 
@@ -101,15 +95,13 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
         ).setWithSpecialTokens(false).build();
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch, fun."));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
 
-        tr = tokenizer.tokenize(List.of("Elasticsearch, fun [MASK]."));
-        tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", "."));
+        tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "[MASK]", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
     }
@@ -124,31 +116,83 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null)
         ).build();
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch",
-            "my little red car",
-            "Godzilla day",
-            "Godzilla Pancake red car day"
-            ));
+        TokenizationResult tr = tokenizer.buildTokenizationResult(
+            List.of(
+                tokenizer.tokenize("Elasticsearch"),
+                tokenizer.tokenize("my little red car"),
+                tokenizer.tokenize("Godzilla day"),
+                tokenizer.tokenize("Godzilla Pancake red car day")
+            )
+        );
         assertThat(tr.getTokenizations(), hasSize(4));
 
         TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search"));
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search"));
         assertArrayEquals(new int[] {0, 1}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0}, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(1);
-        assertThat(tokenization.getTokens(), contains("my", "little", "red", "car"));
+        assertThat(tokenization.getTokens(), arrayContaining("my", "little", "red", "car"));
         assertArrayEquals(new int[] {5, 6, 7, 8}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 2, 3}, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(2);
-        assertThat(tokenization.getTokens(), contains("God", "##zilla", "day"));
+        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "day"));
         assertArrayEquals(new int[] {9, 10, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
 
         tokenization = tr.getTokenizations().get(3);
-        assertThat(tokenization.getTokens(), contains("God", "##zilla", "Pancake", "red", "car", "day"));
+        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "Pancake", "red", "car", "day"));
         assertArrayEquals(new int[] {9, 10, 3, 7, 8, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
     }
+
+    public void testMultiSeqTokenization() {
+        List<String> vocab = List.of(
+            "Elastic",
+            "##search",
+            "is",
+            "fun",
+            "my",
+            "little",
+            "red",
+            "car",
+            "God",
+            "##zilla",
+            BertTokenizer.CLASS_TOKEN,
+            BertTokenizer.SEPARATOR_TOKEN
+        );
+        BertTokenizer tokenizer = BertTokenizer.builder(vocab, Tokenization.createDefault())
+            .setDoLowerCase(false)
+            .setWithSpecialTokens(true)
+            .build();
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car");
+        assertThat(
+            tokenization.getTokens(),
+            arrayContaining(
+                BertTokenizer.CLASS_TOKEN,
+                "Elastic",
+                "##search",
+                "is",
+                "fun",
+                BertTokenizer.SEPARATOR_TOKEN,
+                "God",
+                "##zilla",
+                "my",
+                "little",
+                "red",
+                "car",
+                BertTokenizer.SEPARATOR_TOKEN
+            )
+        );
+        assertArrayEquals(new int[] { 10, 0, 1, 2, 3, 11, 8, 9, 4, 5, 6, 7, 11 }, tokenization.getTokenIds());
+    }
+
+    public void testMultiSeqRequiresSpecialTokens() {
+        BertTokenizer tokenizer = BertTokenizer.builder(List.of("foo"), Tokenization.createDefault())
+            .setDoLowerCase(false)
+            .setWithSpecialTokens(false)
+            .build();
+        expectThrows(Exception.class, () -> tokenizer.tokenize("foo", "foo"));
+    }
 }