Browse Source

[ML] add zero_shot_classification task for BERT nlp models (#77799)

Zero-Shot classification allows for text classification tasks without a pre-trained collection of target labels.

This is achieved through models trained on the Multi-Genre Natural Language Inference (MNLI) dataset. This dataset pairs text sequences with "entailment" clauses. An example could be:

"Throughout all of history, mankind has shown itself resourceful, yet astoundingly short-sighted" could have been paired with the entailment clauses: ["This example is history", "This example is sociology"...]. 

This training set combined with the attention and semantic knowledge in modern day NLP models (BERT, BART, etc.) affords a powerful tool for ad-hoc text classification.

See https://arxiv.org/abs/1909.00161 for a deeper explanation of the MNLI training and how zero-shot works. 

The zero-shot classification task is configured as follows:
```js
{
   // <snip> model configuration </snip>
  "inference_config" : {
    "zero_shot_classification": {
      "classification_labels": ["entailment", "neutral", "contradiction"], // <1>
      "labels": ["sad", "glad", "mad", "rad"], // <2>
      "multi_label": false, // <3>
      "hypothesis_template": "This example is {}.", // <4>
      "tokenization": { /*<snip> tokenization configuration </snip>*/}
    }
  }
}
```
* <1> All zero_shot models return these 3 particular labels when classifying the target sequence. "entailment" is the positive case, "neutral" the case where the sequence is neither positive nor negative, and "contradiction" is the negative case
* <2> This optional parameter supplies the default labels that the zero-shot model attempts to classify against
* <3> When returning the probabilities, should the results assume there is only one true label or multiple true labels
* <4> The hypothesis template when tokenizing the labels. When combining with `sad` the sequence looks like `This example is sad.`

For inference in a pipeline one may provide label updates:
```js
{
  //<snip> pipeline definition </snip>
  "processors": [
    //<snip> other processors </snip>
    {
      "inference": {
        // <snip> general configuration </snip>
        "inference_config": {
          "zero_shot_classification": {
             "labels": ["humanities", "science", "mathematics", "technology"], // <1>
             "multi_label": true // <2>
          }
        }
      }
    }
    //<snip> other processors </snip>
  ]
}
```
* <1> The `labels` we care about, these replace the default ones if they exist. 
* <2> Should the results allow multiple true labels

Similarly one may provide label changes against the `_infer` endpoint
```js
{
   "docs":[{ "text_field": "This is a very happy person"}],
   "inference_config":{"zero_shot_classification":{"labels": ["glad", "sad", "bad", "rad"], "multi_label": false}}
}
```
Benjamin Trent 4 years ago
parent
commit
408489310c
33 changed files with 1493 additions and 255 deletions
  1. 80 18
      docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc
  2. 54 3
      docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc
  3. 108 68
      docs/reference/ml/ml-shared.asciidoc
  4. 28 21
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java
  5. 18 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  6. 21 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java
  7. 245 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java
  8. 201 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java
  9. 24 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java
  10. 61 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java
  11. 134 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java
  12. 4 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java
  13. 9 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  14. 6 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java
  15. 16 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  16. 26 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  17. 7 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  18. 13 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java
  19. 6 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  20. 8 7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  21. 4 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  22. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  23. 7 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java
  24. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  25. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
  26. 194 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  27. 110 52
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  28. 5 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java
  29. 13 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  30. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  31. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  32. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  33. 78 34
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

+ 80 - 18
docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc

@@ -27,7 +27,7 @@ Retrieves configuration information for a trained model.
 [[ml-get-trained-models-prereq]]
 [[ml-get-trained-models-prereq]]
 == {api-prereq-title}
 == {api-prereq-title}
 
 
-Requires the `monitor_ml` cluster privilege. This privilege is included in the 
+Requires the `monitor_ml` cluster privilege. This privilege is included in the
 `machine_learning_user` built-in role.
 `machine_learning_user` built-in role.
 
 
 
 
@@ -71,9 +71,9 @@ default value is empty, indicating no optional fields are included. Valid
 options are:
 options are:
  - `definition`: Includes the model definition.
  - `definition`: Includes the model definition.
  - `feature_importance_baseline`: Includes the baseline for {feat-imp} values.
  - `feature_importance_baseline`: Includes the baseline for {feat-imp} values.
- - `hyperparameters`: Includes the information about hyperparameters used to 
-    train the model. This information consists of the value, the absolute and 
-    relative importance of the hyperparameter as well as an indicator of whether 
+ - `hyperparameters`: Includes the information about hyperparameters used to
+    train the model. This information consists of the value, the absolute and
+    relative importance of the hyperparameter as well as an indicator of whether
     it was specified by the user or tuned during hyperparameter optimization.
     it was specified by the user or tuned during hyperparameter optimization.
  - `total_feature_importance`: Includes the total {feat-imp} for the training
  - `total_feature_importance`: Includes the total {feat-imp} for the training
    data set.
    data set.
@@ -222,8 +222,8 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-ner]
 [%collapsible%open]
 [%collapsible%open]
 ======
 ======
 `classification_labels`::::
 `classification_labels`::::
-(Optional, string) 
-An array of classification labels. NER supports only 
+(Optional, string)
+An array of classification labels. NER supports only
 Inside-Outside-Beginning labels (IOB) and only persons, organizations, locations,
 Inside-Outside-Beginning labels (IOB) and only persons, organizations, locations,
 and miscellaneous. For example:
 and miscellaneous. For example:
 `["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]`.
 `["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]`.
@@ -338,7 +338,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific
 [%collapsible%open]
 [%collapsible%open]
 ======
 ======
 `classification_labels`::::
 `classification_labels`::::
-(Optional, string) 
+(Optional, string)
 An array of classification labels.
 An array of classification labels.
 
 
 `num_top_classes`::::
 `num_top_classes`::::
@@ -414,6 +414,68 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, integer)
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 
 
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
+========
+=======
+`vocabulary`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-vocabulary]
++
+.Properties of vocabulary
+[%collapsible%open]
+=======
+`index`::::
+(Required, string)
+The index where the vocabulary is stored.
+=======
+======
+`zero_shot_classification`::::
+(Object, optional)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification]
++
+.Properties of zero_shot_classification inference
+[%collapsible%open]
+======
+`classification_labels`::::
+(Required, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-classification-labels]
+
+`hypothesis_template`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-hypothesis-template]
+
+`labels`::::
+(Optional, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-labels]
+
+`multi_label`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-multi-label]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+=======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+========
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
 `with_special_tokens`::::
 `with_special_tokens`::::
 (Optional, boolean)
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
@@ -456,7 +518,7 @@ provided.
 =====
 =====
 `index`:::
 `index`:::
 (Required, object)
 (Required, object)
-Indicates that the model definition is stored in an index. It is required to be empty as 
+Indicates that the model definition is stored in an index. It is required to be empty as
 the index for storing model definitions is configured automatically.
 the index for storing model definitions is configured automatically.
 =====
 =====
 // End location
 // End location
@@ -480,7 +542,7 @@ it is a single value. For {classanalysis}, there is a value for each class.
 
 
 `hyperparameters`:::
 `hyperparameters`:::
 (array)
 (array)
-List of the available hyperparameters optimized during the 
+List of the available hyperparameters optimized during the
 `fine_parameter_tuning` phase as well as specified by the user.
 `fine_parameter_tuning` phase as well as specified by the user.
 +
 +
 .Properties of hyperparameters
 .Properties of hyperparameters
@@ -488,10 +550,10 @@ List of the available hyperparameters optimized during the
 ======
 ======
 `absolute_importance`::::
 `absolute_importance`::::
 (double)
 (double)
-A positive number showing how much the parameter influences the variation of the 
-{ml-docs}/dfa-regression-lossfunction.html[loss function]. For 
-hyperparameters with values that are not specified by the user but tuned during 
-hyperparameter optimization. 
+A positive number showing how much the parameter influences the variation of the
+{ml-docs}/dfa-regression-lossfunction.html[loss function]. For
+hyperparameters with values that are not specified by the user but tuned during
+hyperparameter optimization.
 
 
 `max_trees`::::
 `max_trees`::::
 (integer)
 (integer)
@@ -503,14 +565,14 @@ Name of the hyperparameter.
 
 
 `relative_importance`::::
 `relative_importance`::::
 (double)
 (double)
-A number between 0 and 1 showing the proportion of influence on the variation of 
-the loss function among all tuned hyperparameters. For hyperparameters with 
-values that are not specified by the user but tuned during hyperparameter 
+A number between 0 and 1 showing the proportion of influence on the variation of
+the loss function among all tuned hyperparameters. For hyperparameters with
+values that are not specified by the user but tuned during hyperparameter
 optimization.
 optimization.
 
 
 `supplied`::::
 `supplied`::::
 (Boolean)
 (Boolean)
-Indicates if the hyperparameter is specified by the user (`true`) or optimized 
+Indicates if the hyperparameter is specified by the user (`true`) or optimized
 (`false`).
 (`false`).
 
 
 `value`::::
 `value`::::
@@ -602,7 +664,7 @@ Identifier for the trained model.
 `model_type`::
 `model_type`::
 (Optional, string)
 (Optional, string)
 The created model type. By default the model type is `tree_ensemble`.
 The created model type. By default the model type is `tree_ensemble`.
-Appropriate types are: 
+Appropriate types are:
 +
 +
 --
 --
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.

+ 54 - 3
docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc

@@ -377,7 +377,7 @@ A human-readable description of the {infer} trained model.
 `inference_config`::
 `inference_config`::
 (Required, object)
 (Required, object)
 The default configuration for inference. This can be: `regression`,
 The default configuration for inference. This can be: `regression`,
-`classification`, `fill_mask`, `ner`, `text_classification`, or `text_embedding`. 
+`classification`, `fill_mask`, `ner`, `text_classification`, `text_embedding` or `zero_shot_classification`.
 If `regression` or `classification`, it must match the `target_type` of the
 If `regression` or `classification`, it must match the `target_type` of the
 underlying `definition.trained_model`. If `fill_mask`, `ner`,
 underlying `definition.trained_model`. If `fill_mask`, `ner`,
 `text_classification`, or `text_embedding`; the `model_type` must be `pytorch`.
 `text_classification`, or `text_embedding`; the `model_type` must be `pytorch`.
@@ -457,7 +457,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-ner]
 [%collapsible%open]
 [%collapsible%open]
 =====
 =====
 `classification_labels`::::
 `classification_labels`::::
-(Optional, string) 
+(Optional, string)
 An array of classification labels. NER only supports Inside-Outside-Beginning labels (IOB)
 An array of classification labels. NER only supports Inside-Outside-Beginning labels (IOB)
 and only persons, organizations, locations, and miscellaneous.
 and only persons, organizations, locations, and miscellaneous.
 Example: ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
 Example: ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
@@ -614,6 +614,57 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 (Optional, integer)
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
 
 
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
+=======
+======
+=====
+`zero_shot_classification`:::
+(Object, optional)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification]
++
+.Properties of zero_shot_classification inference
+[%collapsible%open]
+=====
+`classification_labels`::::
+(Required, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-classification-labels]
+
+`hypothesis_template`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-hypothesis-template]
+
+`labels`::::
+(Optional, array)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-labels]
+
+`multi_label`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-zero-shot-classification-multi-label]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+=======
+`do_lower_case`::::
+(Optional, boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-do-lower-case]
+
+`max_sequence_length`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-max-sequence-length]
+
 `with_special_tokens`::::
 `with_special_tokens`::::
 (Optional, boolean)
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
@@ -660,7 +711,7 @@ An object map that contains metadata about the model.
 `model_type`::
 `model_type`::
 (Optional, string)
 (Optional, string)
 The created model type. By default the model type is `tree_ensemble`.
 The created model type. By default the model type is `tree_ensemble`.
-Appropriate types are: 
+Appropriate types are:
 +
 +
 --
 --
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.
 * `tree_ensemble`: The model definition is an ensemble model of decision trees.

+ 108 - 68
docs/reference/ml/ml-shared.asciidoc

@@ -323,7 +323,7 @@ end::custom-preprocessor[]
 tag::custom-rules[]
 tag::custom-rules[]
 An array of custom rule objects, which enable you to customize the way detectors
 An array of custom rule objects, which enable you to customize the way detectors
 operate. For example, a rule may dictate to the detector conditions under which
 operate. For example, a rule may dictate to the detector conditions under which
-results should be skipped. {kib} refers to custom rules as _job rules_. For more 
+results should be skipped. {kib} refers to custom rules as _job rules_. For more
 examples, see
 examples, see
 {ml-docs}/ml-configuring-detector-custom-rules.html[Customizing detectors with custom rules].
 {ml-docs}/ml-configuring-detector-custom-rules.html[Customizing detectors with custom rules].
 end::custom-rules[]
 end::custom-rules[]
@@ -526,21 +526,21 @@ end::detector-index[]
 tag::dfas-alpha[]
 tag::dfas-alpha[]
 Advanced configuration option. {ml-cap} uses loss guided tree growing, which
 Advanced configuration option. {ml-cap} uses loss guided tree growing, which
 means that the decision trees grow where the regularized loss decreases most
 means that the decision trees grow where the regularized loss decreases most
-quickly. This parameter affects loss calculations by acting as a multiplier of 
-the tree depth. Higher alpha values result in shallower trees and faster 
-training times. By default, this value is calculated during hyperparameter 
-optimization. It must be greater than or equal to zero. 
+quickly. This parameter affects loss calculations by acting as a multiplier of
+the tree depth. Higher alpha values result in shallower trees and faster
+training times. By default, this value is calculated during hyperparameter
+optimization. It must be greater than or equal to zero.
 end::dfas-alpha[]
 end::dfas-alpha[]
 
 
 tag::dfas-downsample-factor[]
 tag::dfas-downsample-factor[]
-Advanced configuration option. Controls the fraction of data that is used to 
-compute the derivatives of the loss function for tree training. A small value 
-results in the use of a small fraction of the data. If this value is set to be 
-less than 1, accuracy typically improves. However, too small a value may result 
+Advanced configuration option. Controls the fraction of data that is used to
+compute the derivatives of the loss function for tree training. A small value
+results in the use of a small fraction of the data. If this value is set to be
+less than 1, accuracy typically improves. However, too small a value may result
 in poor convergence for the ensemble and so require more trees. For more
 in poor convergence for the ensemble and so require more trees. For more
 information about shrinkage, refer to
 information about shrinkage, refer to
 {wikipedia}/Gradient_boosting#Stochastic_gradient_boosting[this wiki article].
 {wikipedia}/Gradient_boosting#Stochastic_gradient_boosting[this wiki article].
-By default, this value is calculated during hyperparameter optimization. It 
+By default, this value is calculated during hyperparameter optimization. It
 must be greater than zero and less than or equal to 1.
 must be greater than zero and less than or equal to 1.
 end::dfas-downsample-factor[]
 end::dfas-downsample-factor[]
 
 
@@ -553,9 +553,9 @@ By default, early stoppping is enabled.
 end::dfas-early-stopping-enabled[]
 end::dfas-early-stopping-enabled[]
 
 
 tag::dfas-eta-growth[]
 tag::dfas-eta-growth[]
-Advanced configuration option. Specifies the rate at which `eta` increases for 
-each new tree that is added to the forest. For example, a rate of 1.05 
-increases `eta` by 5% for each extra tree. By default, this value is calculated 
+Advanced configuration option. Specifies the rate at which `eta` increases for
+each new tree that is added to the forest. For example, a rate of 1.05
+increases `eta` by 5% for each extra tree. By default, this value is calculated
 during hyperparameter optimization. It must be between 0.5 and 2.
 during hyperparameter optimization. It must be between 0.5 and 2.
 end::dfas-eta-growth[]
 end::dfas-eta-growth[]
 
 
@@ -565,16 +565,16 @@ candidate split.
 end::dfas-feature-bag-fraction[]
 end::dfas-feature-bag-fraction[]
 
 
 tag::dfas-feature-processors[]
 tag::dfas-feature-processors[]
-Advanced configuration option. A collection of feature preprocessors that modify 
-one or more included fields. The analysis uses the resulting one or more 
-features instead of the original document field. However, these features are 
-ephemeral; they are not stored in the destination index. Multiple 
-`feature_processors` entries can refer to the same document fields. Automatic 
-categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs 
+Advanced configuration option. A collection of feature preprocessors that modify
+one or more included fields. The analysis uses the resulting one or more
+features instead of the original document field. However, these features are
+ephemeral; they are not stored in the destination index. Multiple
+`feature_processors` entries can refer to the same document fields. Automatic
+categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs
 for the fields that are unprocessed by a custom processor or that have
 for the fields that are unprocessed by a custom processor or that have
-categorical values. Use this property only if you want to override the automatic 
-feature encoding of the specified fields. Refer to 
-{ml-docs}/ml-feature-processors.html[{dfanalytics} feature processors] to learn 
+categorical values. Use this property only if you want to override the automatic
+feature encoding of the specified fields. Refer to
+{ml-docs}/ml-feature-processors.html[{dfanalytics} feature processors] to learn
 more.
 more.
 end::dfas-feature-processors[]
 end::dfas-feature-processors[]
 
 
@@ -591,13 +591,13 @@ The configuration information necessary to perform frequency encoding.
 end::dfas-feature-processors-frequency[]
 end::dfas-feature-processors-frequency[]
 
 
 tag::dfas-feature-processors-frequency-map[]
 tag::dfas-feature-processors-frequency-map[]
-The resulting frequency map for the field value. If the field value is missing 
+The resulting frequency map for the field value. If the field value is missing
 from the `frequency_map`, the resulting value is `0`.
 from the `frequency_map`, the resulting value is `0`.
 end::dfas-feature-processors-frequency-map[]
 end::dfas-feature-processors-frequency-map[]
 
 
 tag::dfas-feature-processors-multi[]
 tag::dfas-feature-processors-multi[]
-The configuration information necessary to perform multi encoding. It allows 
-multiple processors to be changed together. This way the output of a processor 
+The configuration information necessary to perform multi encoding. It allows
+multiple processors to be changed together. This way the output of a processor
 can then be passed to another as an input.
 can then be passed to another as an input.
 end::dfas-feature-processors-multi[]
 end::dfas-feature-processors-multi[]
 
 
@@ -606,10 +606,10 @@ The ordered array of custom processors to execute. Must be more than 1.
 end::dfas-feature-processors-multi-proc[]
 end::dfas-feature-processors-multi-proc[]
 
 
 tag::dfas-feature-processors-ngram[]
 tag::dfas-feature-processors-ngram[]
-The configuration information necessary to perform n-gram encoding. Features 
-created by this encoder have the following name format: 
-`<feature_prefix>.<ngram><string position>`. For example, if the 
-`feature_prefix` is `f`, the feature name for the second unigram in a string is 
+The configuration information necessary to perform n-gram encoding. Features
+created by this encoder have the following name format:
+`<feature_prefix>.<ngram><string position>`. For example, if the
+`feature_prefix` is `f`, the feature name for the second unigram in a string is
 `f.11`.
 `f.11`.
 end::dfas-feature-processors-ngram[]
 end::dfas-feature-processors-ngram[]
 
 
@@ -622,17 +622,17 @@ The name of the text field to encode.
 end::dfas-feature-processors-ngram-field[]
 end::dfas-feature-processors-ngram-field[]
 
 
 tag::dfas-feature-processors-ngram-length[]
 tag::dfas-feature-processors-ngram-length[]
-Specifies the length of the n-gram substring. Defaults to `50`. Must be greater 
+Specifies the length of the n-gram substring. Defaults to `50`. Must be greater
 than `0`.
 than `0`.
 end::dfas-feature-processors-ngram-length[]
 end::dfas-feature-processors-ngram-length[]
 
 
 tag::dfas-feature-processors-ngram-ngrams[]
 tag::dfas-feature-processors-ngram-ngrams[]
-Specifies which n-grams to gather. It’s an array of integer values where the 
+Specifies which n-grams to gather. It’s an array of integer values where the
 minimum value is 1, and a maximum value is 5.
 minimum value is 1, and a maximum value is 5.
 end::dfas-feature-processors-ngram-ngrams[]
 end::dfas-feature-processors-ngram-ngrams[]
 
 
 tag::dfas-feature-processors-ngram-start[]
 tag::dfas-feature-processors-ngram-start[]
-Specifies the zero-indexed start of the n-gram substring. Negative values are 
+Specifies the zero-indexed start of the n-gram substring. Negative values are
 allowed for encoding n-grams of string suffixes. Defaults to `0`.
 allowed for encoding n-grams of string suffixes. Defaults to `0`.
 end::dfas-feature-processors-ngram-start[]
 end::dfas-feature-processors-ngram-start[]
 
 
@@ -686,19 +686,19 @@ decision tree when the tree is trained.
 end::dfas-num-splits[]
 end::dfas-num-splits[]
 
 
 tag::dfas-soft-limit[]
 tag::dfas-soft-limit[]
-Advanced configuration option. {ml-cap} uses loss guided tree growing, which 
-means that the decision trees grow where the regularized loss decreases most 
-quickly. This soft limit combines with the `soft_tree_depth_tolerance` to 
-penalize trees that exceed the specified depth; the regularized loss increases 
-quickly beyond this depth. By default, this value is calculated during 
+Advanced configuration option. {ml-cap} uses loss guided tree growing, which
+means that the decision trees grow where the regularized loss decreases most
+quickly. This soft limit combines with the `soft_tree_depth_tolerance` to
+penalize trees that exceed the specified depth; the regularized loss increases
+quickly beyond this depth. By default, this value is calculated during
 hyperparameter optimization. It must be greater than or equal to 0.
 hyperparameter optimization. It must be greater than or equal to 0.
 end::dfas-soft-limit[]
 end::dfas-soft-limit[]
 
 
 tag::dfas-soft-tolerance[]
 tag::dfas-soft-tolerance[]
-Advanced configuration option. This option controls how quickly the regularized 
-loss increases when the tree depth exceeds `soft_tree_depth_limit`. By default, 
-this value is calculated during hyperparameter optimization. It must be greater 
-than or equal to 0.01. 
+Advanced configuration option. This option controls how quickly the regularized
+loss increases when the tree depth exceeds `soft_tree_depth_limit`. By default,
+this value is calculated during hyperparameter optimization. It must be greater
+than or equal to 0.01.
 end::dfas-soft-tolerance[]
 end::dfas-soft-tolerance[]
 
 
 tag::dfas-timestamp[]
 tag::dfas-timestamp[]
@@ -744,7 +744,7 @@ end::empty-bucket-count[]
 tag::eta[]
 tag::eta[]
 Advanced configuration option. The shrinkage applied to the weights. Smaller
 Advanced configuration option. The shrinkage applied to the weights. Smaller
 values result in larger forests which have a better generalization error.
 values result in larger forests which have a better generalization error.
-However, larger forests cause slower training. For more information about 
+However, larger forests cause slower training. For more information about
 shrinkage, refer to
 shrinkage, refer to
 {wikipedia}/Gradient_boosting#Shrinkage[this wiki article].
 {wikipedia}/Gradient_boosting#Shrinkage[this wiki article].
 By default, this value is calculated during hyperparameter optimization. It must
 By default, this value is calculated during hyperparameter optimization. It must
@@ -833,10 +833,10 @@ end::function[]
 
 
 tag::gamma[]
 tag::gamma[]
 Advanced configuration option. Regularization parameter to prevent overfitting
 Advanced configuration option. Regularization parameter to prevent overfitting
-on the training data set. Multiplies a linear penalty associated with the size 
-of individual trees in the forest. A high gamma value causes training to prefer 
-small trees. A small gamma value results in larger individual trees and slower 
-training. By default, this value is calculated during hyperparameter 
+on the training data set. Multiplies a linear penalty associated with the size
+of individual trees in the forest. A high gamma value causes training to prefer
+small trees. A small gamma value results in larger individual trees and slower
+training. By default, this value is calculated during hyperparameter
 optimization. It must be a nonnegative value.
 optimization. It must be a nonnegative value.
 end::gamma[]
 end::gamma[]
 
 
@@ -849,7 +849,7 @@ An array of index names. Wildcards are supported. For example:
 `["it_ops_metrics", "server*"]`.
 `["it_ops_metrics", "server*"]`.
 +
 +
 --
 --
-NOTE: If any indices are in remote clusters then the {ml} nodes need to have the 
+NOTE: If any indices are in remote clusters then the {ml} nodes need to have the
 `remote_cluster_client` role.
 `remote_cluster_client` role.
 
 
 --
 --
@@ -921,7 +921,7 @@ BERT-style tokenization is to be performed with the enclosed settings.
 end::inference-config-nlp-tokenization-bert[]
 end::inference-config-nlp-tokenization-bert[]
 
 
 tag::inference-config-nlp-tokenization-bert-do-lower-case[]
 tag::inference-config-nlp-tokenization-bert-do-lower-case[]
-Should the tokenization lower case the text sequence when building 
+Should the tokenization lower case the text sequence when building
 the tokens.
 the tokens.
 end::inference-config-nlp-tokenization-bert-do-lower-case[]
 end::inference-config-nlp-tokenization-bert-do-lower-case[]
 
 
@@ -930,7 +930,7 @@ Tokenize with special tokens. The tokens typically included in BERT-style tokeni
 +
 +
 --
 --
 * `[CLS]`: The first token of the sequence being classified.
 * `[CLS]`: The first token of the sequence being classified.
-* `[SEP]`: Indicates sequence separation. 
+* `[SEP]`: Indicates sequence separation.
 --
 --
 end::inference-config-nlp-tokenization-bert-with-special-tokens[]
 end::inference-config-nlp-tokenization-bert-with-special-tokens[]
 
 
@@ -998,6 +998,46 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 end::inference-config-results-field-processor[]
 end::inference-config-results-field-processor[]
 
 
+tag::inference-config-zero-shot-classification[]
+Configures a zero-shot classification task. Zero-shot classification allows for
+text classification to occur without pre-determined labels. At inference time,
+it is possible to adjust the labels to classify. This makes this type of model
+and task exceptionally flexible.
+
+If consistently classifying the same labels, it may be better to use a fine turned
+text classification model.
+end::inference-config-zero-shot-classification[]
+
+tag::inference-config-zero-shot-classification-classification-labels[]
+The classification labels used during the zero-shot classification. Classification
+labels must not be empty or null and only set at model creation. They must be all three
+of ["entailment", "neutral", "contradiction"].
+
+NOTE: This is NOT the same as `labels` which are the values that zero-shot is attempting to
+      classify.
+end::inference-config-zero-shot-classification-classification-labels[]
+
+tag::inference-config-zero-shot-classification-hypothesis-template[]
+This is the template used when tokenizing the sequences for classification.
+
+The labels replace the `{}` value in the text. The default value is:
+`This example is {}.`
+end::inference-config-zero-shot-classification-hypothesis-template[]
+
+tag::inference-config-zero-shot-classification-labels[]
+The labels to classify. Can be set at creation for default labels, and
+then updated during inference.
+end::inference-config-zero-shot-classification-labels[]
+
+tag::inference-config-zero-shot-classification-multi-label[]
+Indicates if more than one `true` label is possible given the input.
+
+This is useful when labeling text that could pertain to more than one of the
+input labels.
+
+Defaults to `false`.
+end::inference-config-zero-shot-classification-multi-label[]
+
 tag::inference-metadata-feature-importance-feature-name[]
 tag::inference-metadata-feature-importance-feature-name[]
 The feature for which this importance was calculated.
 The feature for which this importance was calculated.
 end::inference-metadata-feature-importance-feature-name[]
 end::inference-metadata-feature-importance-feature-name[]
@@ -1102,11 +1142,11 @@ end::job-id-datafeed[]
 tag::lambda[]
 tag::lambda[]
 Advanced configuration option. Regularization parameter to prevent overfitting
 Advanced configuration option. Regularization parameter to prevent overfitting
 on the training data set. Multiplies an L2 regularization term which applies to
 on the training data set. Multiplies an L2 regularization term which applies to
-leaf weights of the individual trees in the forest. A high lambda value causes 
-training to favor small leaf weights. This behavior makes the prediction 
+leaf weights of the individual trees in the forest. A high lambda value causes
+training to favor small leaf weights. This behavior makes the prediction
 function smoother at the expense of potentially not being able to capture
 function smoother at the expense of potentially not being able to capture
 relevant relationships between the features and the {depvar}. A small lambda
 relevant relationships between the features and the {depvar}. A small lambda
-value results in large individual trees and slower training. By default, this 
+value results in large individual trees and slower training. By default, this
 value is calculated during hyperparameter optimization. It must be a nonnegative
 value is calculated during hyperparameter optimization. It must be a nonnegative
 value.
 value.
 end::lambda[]
 end::lambda[]
@@ -1151,13 +1191,13 @@ set.
 end::max-empty-searches[]
 end::max-empty-searches[]
 
 
 tag::max-trees[]
 tag::max-trees[]
-Advanced configuration option. Defines the maximum number of decision trees in 
-the forest. The maximum value is 2000. By default, this value is calculated 
+Advanced configuration option. Defines the maximum number of decision trees in
+the forest. The maximum value is 2000. By default, this value is calculated
 during hyperparameter optimization.
 during hyperparameter optimization.
 end::max-trees[]
 end::max-trees[]
 
 
 tag::max-trees-trained-models[]
 tag::max-trees-trained-models[]
-The maximum number of decision trees in the forest. The maximum value is 2000. 
+The maximum number of decision trees in the forest. The maximum value is 2000.
 By default, this value is calculated during hyperparameter optimization.
 By default, this value is calculated during hyperparameter optimization.
 end::max-trees-trained-models[]
 end::max-trees-trained-models[]
 
 
@@ -1222,7 +1262,7 @@ default value for jobs created in version 6.1 and later is `1024mb`. If the
 than `1024mb`, however, that value is used instead. The default value is
 than `1024mb`, however, that value is used instead. The default value is
 relatively small to ensure that high resource usage is a conscious decision. If
 relatively small to ensure that high resource usage is a conscious decision. If
 you have jobs that are expected to analyze high cardinality fields, you will
 you have jobs that are expected to analyze high cardinality fields, you will
-likely need to use a higher value. 
+likely need to use a higher value.
 +
 +
 If you specify a number instead of a string, the units are assumed to be MiB.
 If you specify a number instead of a string, the units are assumed to be MiB.
 Specifying a string is recommended for clarity. If you specify a byte size unit
 Specifying a string is recommended for clarity. If you specify a byte size unit
@@ -1299,11 +1339,11 @@ Only the specified `terms` can be viewed when using the Single Metric Viewer.
 end::model-plot-config-terms[]
 end::model-plot-config-terms[]
 
 
 tag::model-prune-window[]
 tag::model-prune-window[]
-Advanced configuration option. 
-Affects the pruning of models that have not been updated for the given time 
-duration. The value must be set to a multiple of the `bucket_span`. If set too 
-low, important information may be removed from the model. Typically, set to 
-`30d` or longer. If not set, model pruning only occurs if the model memory 
+Advanced configuration option.
+Affects the pruning of models that have not been updated for the given time
+duration. The value must be set to a multiple of the `bucket_span`. If set too
+low, important information may be removed from the model. Typically, set to
+`30d` or longer. If not set, model pruning only occurs if the model memory
 status reaches the soft limit or the hard limit.
 status reaches the soft limit or the hard limit.
 end::model-prune-window[]
 end::model-prune-window[]
 
 
@@ -1391,10 +1431,10 @@ end::open-time[]
 
 
 tag::out-of-order-timestamp-count[]
 tag::out-of-order-timestamp-count[]
 The number of input documents that have a timestamp chronologically
 The number of input documents that have a timestamp chronologically
-preceding the start of the current anomaly detection bucket offset by 
-the latency window. This information is applicable only when you provide 
-data to the {anomaly-job} by using the <<ml-post-data,post data API>>. 
-These out of order documents are discarded, since jobs require time 
+preceding the start of the current anomaly detection bucket offset by
+the latency window. This information is applicable only when you provide
+data to the {anomaly-job} by using the <<ml-post-data,post data API>>.
+These out of order documents are discarded, since jobs require time
 series data to be in ascending chronological order.
 series data to be in ascending chronological order.
 end::out-of-order-timestamp-count[]
 end::out-of-order-timestamp-count[]
 
 
@@ -1459,9 +1499,9 @@ number of {es} documents.
 end::processed-record-count[]
 end::processed-record-count[]
 
 
 tag::randomize-seed[]
 tag::randomize-seed[]
-Defines the seed for the random generator that is used to pick training data. By 
-default, it is randomly generated. Set it to a specific value to use the same 
-training data each time you start a job (assuming other related parameters such 
+Defines the seed for the random generator that is used to pick training data. By
+default, it is randomly generated. Set it to a specific value to use the same
+training data each time you start a job (assuming other related parameters such
 as `source` and `analyzed_fields` are the same).
 as `source` and `analyzed_fields` are the same).
 end::randomize-seed[]
 end::randomize-seed[]
 
 

+ 28 - 21
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -11,19 +11,19 @@ import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.ParseField;
-import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 
 import java.io.IOException;
 import java.io.IOException;
@@ -31,6 +31,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Objects;
+import java.util.Optional;
 
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;
 import static org.elasticsearch.action.ValidateActions.addValidationError;
 
 
@@ -45,11 +46,12 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
         super(NAME, InferTrainedModelDeploymentAction.Response::new);
         super(NAME, InferTrainedModelDeploymentAction.Response::new);
     }
     }
 
 
-    public static class Request extends BaseTasksRequest<Request> implements ToXContentObject {
+    public static class Request extends BaseTasksRequest<Request> {
 
 
         public static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
         public static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
         public static final ParseField DOCS = new ParseField("docs");
         public static final ParseField DOCS = new ParseField("docs");
         public static final ParseField TIMEOUT = new ParseField("timeout");
         public static final ParseField TIMEOUT = new ParseField("timeout");
+        public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
 
 
         public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(10);
         public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(10);
 
 
@@ -58,6 +60,11 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             PARSER.declareString(Request.Builder::setDeploymentId, DEPLOYMENT_ID);
             PARSER.declareString(Request.Builder::setDeploymentId, DEPLOYMENT_ID);
             PARSER.declareObjectArray(Request.Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
             PARSER.declareObjectArray(Request.Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
             PARSER.declareString(Request.Builder::setTimeout, TIMEOUT);
             PARSER.declareString(Request.Builder::setTimeout, TIMEOUT);
+            PARSER.declareNamedObject(
+                Request.Builder::setUpdate,
+                ((p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, c)),
+                INFERENCE_CONFIG
+            );
         }
         }
 
 
         public static Request parseRequest(String deploymentId, XContentParser parser) {
         public static Request parseRequest(String deploymentId, XContentParser parser) {
@@ -70,16 +77,19 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
 
 
         private final String deploymentId;
         private final String deploymentId;
         private final List<Map<String, Object>> docs;
         private final List<Map<String, Object>> docs;
+        private final InferenceConfigUpdate update;
 
 
-        public Request(String deploymentId, List<Map<String, Object>> docs) {
+        public Request(String deploymentId, InferenceConfigUpdate update, List<Map<String, Object>> docs) {
             this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);
             this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);
             this.docs = ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS);
             this.docs = ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS);
+            this.update = update;
         }
         }
 
 
         public Request(StreamInput in) throws IOException {
         public Request(StreamInput in) throws IOException {
             super(in);
             super(in);
             deploymentId = in.readString();
             deploymentId = in.readString();
             docs = Collections.unmodifiableList(in.readList(StreamInput::readMap));
             docs = Collections.unmodifiableList(in.readList(StreamInput::readMap));
+            update = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
         }
         }
 
 
         public String getDeploymentId() {
         public String getDeploymentId() {
@@ -90,6 +100,10 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return docs;
             return docs;
         }
         }
 
 
+        public InferenceConfigUpdate getUpdate() {
+            return Optional.ofNullable(update).orElse(new EmptyConfigUpdate());
+        }
+
         @Override
         @Override
         public TimeValue getTimeout() {
         public TimeValue getTimeout() {
             TimeValue tv = super.getTimeout();
             TimeValue tv = super.getTimeout();
@@ -124,16 +138,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             super.writeTo(out);
             super.writeTo(out);
             out.writeString(deploymentId);
             out.writeString(deploymentId);
             out.writeCollection(docs, StreamOutput::writeMap);
             out.writeCollection(docs, StreamOutput::writeMap);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
-            builder.startObject();
-            builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId);
-            builder.field(DOCS.getPreferredName(), docs);
-            builder.field(TIMEOUT.getPreferredName(), getTimeout().getStringRep());
-            builder.endObject();
-            return builder;
+            out.writeOptionalNamedWriteable(update);
         }
         }
 
 
         @Override
         @Override
@@ -148,17 +153,13 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o;
             InferTrainedModelDeploymentAction.Request that = (InferTrainedModelDeploymentAction.Request) o;
             return Objects.equals(deploymentId, that.deploymentId)
             return Objects.equals(deploymentId, that.deploymentId)
                 && Objects.equals(docs, that.docs)
                 && Objects.equals(docs, that.docs)
+                && Objects.equals(update, that.update)
                 && Objects.equals(getTimeout(), that.getTimeout());
                 && Objects.equals(getTimeout(), that.getTimeout());
         }
         }
 
 
         @Override
         @Override
         public int hashCode() {
         public int hashCode() {
-            return Objects.hash(deploymentId, docs, getTimeout());
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
+            return Objects.hash(deploymentId, update, docs, getTimeout());
         }
         }
 
 
         public static class Builder {
         public static class Builder {
@@ -166,6 +167,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             private String deploymentId;
             private String deploymentId;
             private List<Map<String, Object>> docs;
             private List<Map<String, Object>> docs;
             private TimeValue timeout;
             private TimeValue timeout;
+            private InferenceConfigUpdate update;
 
 
             private Builder() {}
             private Builder() {}
 
 
@@ -184,12 +186,17 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 return this;
                 return this;
             }
             }
 
 
+            public Builder setUpdate(InferenceConfigUpdate update) {
+                this.update = update;
+                return this;
+            }
+
             private Builder setTimeout(String timeout) {
             private Builder setTimeout(String timeout) {
                 return setTimeout(TimeValue.parseTimeValue(timeout, TIMEOUT.getPreferredName()));
                 return setTimeout(TimeValue.parseTimeValue(timeout, TIMEOUT.getPreferredName()));
             }
             }
 
 
             public Request build() {
             public Request build() {
-                Request request = new Request(deploymentId, docs);
+                Request request = new Request(deploymentId, update, docs);
                 if (timeout != null) {
                 if (timeout != null) {
                     request.setTimeout(timeout);
                     request.setTimeout(timeout);
                 }
                 }

+ 18 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -52,6 +52,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfi
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
@@ -184,11 +186,23 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new ParseField(TextEmbeddingConfig.NAME), TextEmbeddingConfig::fromXContentLenient));
             new ParseField(TextEmbeddingConfig.NAME), TextEmbeddingConfig::fromXContentLenient));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(TextEmbeddingConfig.NAME),
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(TextEmbeddingConfig.NAME),
             TextEmbeddingConfig::fromXContentStrict));
             TextEmbeddingConfig::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
+            new ParseField(ZeroShotClassificationConfig.NAME), ZeroShotClassificationConfig::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class,
+            new ParseField(ZeroShotClassificationConfig.NAME),
+            ZeroShotClassificationConfig::fromXContentStrict));
 
 
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
             ClassificationConfigUpdate::fromXContentStrict));
             ClassificationConfigUpdate::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
             RegressionConfigUpdate::fromXContentStrict));
             RegressionConfigUpdate::fromXContentStrict));
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                InferenceConfigUpdate.class,
+                new ParseField(ZeroShotClassificationConfigUpdate.NAME),
+                ZeroShotClassificationConfigUpdate::fromXContentStrict
+            )
+        );
 
 
         // Inference models
         // Inference models
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
@@ -288,6 +302,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             PassThroughConfig.NAME, PassThroughConfig::new));
             PassThroughConfig.NAME, PassThroughConfig::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             TextEmbeddingConfig.NAME, TextEmbeddingConfig::new));
             TextEmbeddingConfig.NAME, TextEmbeddingConfig::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
+            ZeroShotClassificationConfig.NAME, ZeroShotClassificationConfig::new));
 
 
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));
             ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));
@@ -297,6 +313,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             ResultsFieldUpdate.NAME, ResultsFieldUpdate::new));
             ResultsFieldUpdate.NAME, ResultsFieldUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             EmptyConfigUpdate.NAME, EmptyConfigUpdate::new));
             EmptyConfigUpdate.NAME, EmptyConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            ZeroShotClassificationConfigUpdate.NAME, ZeroShotClassificationConfigUpdate::new));
 
 
         // Location
         // Location
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModelLocation.class,
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModelLocation.class,

+ 21 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.xcontent.ParseField;
+
+public abstract class NlpConfigUpdate implements InferenceConfigUpdate {
+
+    static ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+
+    @Override
+    public InferenceConfig toConfig() {
+        throw new UnsupportedOperationException("cannot serialize to nodes before 7.8");
+    }
+
+}

+ 245 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java

@@ -0,0 +1,245 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+/**
+ * This builds out a 0-shot classification task.
+ *
+ * The 0-shot methodology assumed is MNLI optimized task. For further info see: https://arxiv.org/abs/1909.00161
+ *
+ * The model must have been trained against the three MNLI classification labels
+ * (entailment, neutral, contradiction); the constructor rejects anything else.
+ */
+public class ZeroShotClassificationConfig implements NlpConfig {
+
+    public static final String NAME = "zero_shot_classification";
+    public static final ParseField HYPOTHESIS_TEMPLATE = new ParseField("hypothesis_template");
+    public static final ParseField MULTI_LABEL = new ParseField("multi_label");
+    public static final ParseField LABELS = new ParseField("labels");
+
+    // Parses user-supplied configuration; unknown fields are rejected.
+    public static ZeroShotClassificationConfig fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    // Parses internally stored configuration; unknown fields are ignored for forward compatibility.
+    public static ZeroShotClassificationConfig fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    // The only classification labels permitted; membership is checked case-insensitively in the constructor.
+    private static final Set<String> REQUIRED_CLASSIFICATION_LABELS = new TreeSet<>(List.of("entailment", "neutral", "contradiction"));
+    // "{}" is substituted with each candidate label to build the hypothesis sequence.
+    private static final String DEFAULT_HYPOTHESIS_TEMPLATE = "This example is {}.";
+    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> LENIENT_PARSER = createParser(true);
+
+    @SuppressWarnings({ "unchecked"})
+    private static ConstructingObjectParser<ZeroShotClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<ZeroShotClassificationConfig, Void> parser = new ConstructingObjectParser<>(
+            NAME,
+            ignoreUnknownFields,
+            a -> new ZeroShotClassificationConfig(
+                (List<String>)a[0],
+                (VocabularyConfig) a[1],
+                (Tokenization) a[2],
+                (String) a[3],
+                (Boolean) a[4],
+                (List<String>) a[5]
+            )
+        );
+        parser.declareStringArray(ConstructingObjectParser.constructorArg(), CLASSIFICATION_LABELS);
+        // The vocabulary location is managed internally: it may only appear in lenient (stored-config)
+        // parsing and is rejected when a user supplies it at model creation time.
+        parser.declareObject(
+            ConstructingObjectParser.optionalConstructorArg(),
+            (p, c) -> {
+                if (ignoreUnknownFields == false) {
+                    throw ExceptionsHelper.badRequestException(
+                        "illegal setting [{}] on inference model creation",
+                        VOCABULARY.getPreferredName()
+                    );
+                }
+                return VocabularyConfig.fromXContentLenient(p);
+            },
+            VOCABULARY
+        );
+        parser.declareNamedObject(
+            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+                TOKENIZATION
+        );
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), HYPOTHESIS_TEMPLATE);
+        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
+        parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
+        return parser;
+    }
+
+    private final VocabularyConfig vocabularyConfig;
+    private final Tokenization tokenization;
+    // The fixed MNLI label set, kept in the caller-supplied order and casing.
+    private final List<String> classificationLabels;
+    // Default candidate labels to classify against; may be null (overridable per-request via the update).
+    private final List<String> labels;
+    private final boolean isMultiLabel;
+    private final String hypothesisTemplate;
+
+    /**
+     * @param classificationLabels required; must contain exactly entailment, neutral and
+     *                             contradiction (any casing, any order)
+     * @param vocabularyConfig     optional; defaults to the native definition store index
+     * @param tokenization         optional; defaults to {@code Tokenization.createDefault()}
+     * @param hypothesisTemplate   optional; defaults to "This example is {}."
+     * @param isMultiLabel         optional; defaults to false
+     * @param labels               optional; when supplied must be non-empty
+     */
+    public ZeroShotClassificationConfig(
+        List<String> classificationLabels,
+        @Nullable VocabularyConfig vocabularyConfig,
+        @Nullable Tokenization tokenization,
+        @Nullable String hypothesisTemplate,
+        @Nullable Boolean isMultiLabel,
+        @Nullable List<String> labels
+    ) {
+        this.classificationLabels = ExceptionsHelper.requireNonNull(classificationLabels, CLASSIFICATION_LABELS);
+        if (this.classificationLabels.size() != 3) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] must contain exactly the three values {}",
+                CLASSIFICATION_LABELS.getPreferredName(),
+                REQUIRED_CLASSIFICATION_LABELS
+            );
+        }
+        // Case-insensitive membership check; the stored labels keep their original casing.
+        List<String> badLabels = classificationLabels.stream()
+            .map(s -> s.toLowerCase(Locale.ROOT))
+            .filter(c -> REQUIRED_CLASSIFICATION_LABELS.contains(c) == false)
+            .collect(Collectors.toList());
+        if (badLabels.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] must contain exactly the three values {}. Invalid labels {}",
+                CLASSIFICATION_LABELS.getPreferredName(),
+                REQUIRED_CLASSIFICATION_LABELS,
+                badLabels
+            );
+        }
+        this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
+            .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
+        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        this.isMultiLabel = isMultiLabel != null && isMultiLabel;
+        this.hypothesisTemplate = Optional.ofNullable(hypothesisTemplate).orElse(DEFAULT_HYPOTHESIS_TEMPLATE);
+        // null is allowed (labels may be supplied later via an update), empty is not.
+        this.labels = labels;
+        if (labels != null && labels.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
+        }
+    }
+
+    // Wire deserialization. Field order must be kept in sync with writeTo.
+    public ZeroShotClassificationConfig(StreamInput in) throws IOException {
+        vocabularyConfig = new VocabularyConfig(in);
+        tokenization = in.readNamedWriteable(Tokenization.class);
+        classificationLabels = in.readStringList();
+        isMultiLabel = in.readBoolean();
+        hypothesisTemplate = in.readString();
+        labels = in.readOptionalStringList();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        vocabularyConfig.writeTo(out);
+        out.writeNamedWriteable(tokenization);
+        out.writeStringCollection(classificationLabels);
+        out.writeBoolean(isMultiLabel);
+        out.writeString(hypothesisTemplate);
+        out.writeOptionalStringCollection(labels);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        builder.field(MULTI_LABEL.getPreferredName(), isMultiLabel);
+        builder.field(HYPOTHESIS_TEMPLATE.getPreferredName(), hypothesisTemplate);
+        // labels is optional and omitted when unset.
+        if (labels != null) {
+            builder.field(LABELS.getPreferredName(), labels);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        // Always false: no classification/regression target type applies to this task.
+        return false;
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return Version.V_8_0_0;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        ZeroShotClassificationConfig that = (ZeroShotClassificationConfig) o;
+        return Objects.equals(vocabularyConfig, that.vocabularyConfig)
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(isMultiLabel, that.isMultiLabel)
+            && Objects.equals(hypothesisTemplate, that.hypothesisTemplate)
+            && Objects.equals(labels, that.labels)
+            && Objects.equals(classificationLabels, that.classificationLabels);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, hypothesisTemplate, isMultiLabel, labels);
+    }
+
+    @Override
+    public VocabularyConfig getVocabularyConfig() {
+        return vocabularyConfig;
+    }
+
+    @Override
+    public Tokenization getTokenization() {
+        return tokenization;
+    }
+
+    public List<String> getClassificationLabels() {
+        return classificationLabels;
+    }
+
+    public boolean isMultiLabel() {
+        return isMultiLabel;
+    }
+
+    public String getHypothesisTemplate() {
+        return hypothesisTemplate;
+    }
+
+    // Never null: absent labels are reported as an empty list.
+    public List<String> getLabels() {
+        return Optional.ofNullable(labels).orElse(List.of());
+    }
+
+    @Override
+    public boolean isAllocateOnly() {
+        // NOTE(review): presumably this task may only run against a model allocated to a
+        // node's native process (not the in-JVM inference path) — confirm with callers.
+        return true;
+    }
+
+}

+ 201 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java

@@ -0,0 +1,201 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.LABELS;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.MULTI_LABEL;
+
+/**
+ * Per-request update for a stored {@link ZeroShotClassificationConfig}: callers may
+ * override the candidate {@code labels} and/or the {@code multi_label} flag.
+ * The fixed MNLI classification labels are never updatable.
+ */
+public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+
+    public static final String NAME = "zero_shot_classification";
+
+    public static ZeroShotClassificationConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    /**
+     * Builds an update from an untyped map (e.g. an ingest pipeline inference_config).
+     * Only {@code labels} and {@code multi_label} are recognized; anything else is rejected.
+     */
+    @SuppressWarnings({ "unchecked"})
+    public static ZeroShotClassificationConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        Boolean isMultiLabel = (Boolean)options.remove(MULTI_LABEL.getPreferredName());
+        List<String> labels = (List<String>)options.remove(LABELS.getPreferredName());
+        if (options.isEmpty() == false) {
+            // Report only the keys that were NOT consumed; previously this echoed the whole
+            // input map, wrongly listing recognized fields as unrecognized.
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+    }
+
+    @SuppressWarnings({ "unchecked"})
+    private static final ConstructingObjectParser<ZeroShotClassificationConfigUpdate, Void> STRICT_PARSER = new ConstructingObjectParser<>(
+        NAME,
+        a -> new ZeroShotClassificationConfigUpdate((List<String>)a[0], (Boolean) a[1])
+    );
+
+    static {
+        STRICT_PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
+        STRICT_PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
+    }
+
+    // Candidate labels to classify against; null means "keep the stored labels".
+    private final List<String> labels;
+    // Tri-state: null means "keep the stored multi_label setting".
+    private final Boolean isMultiLabel;
+
+    /**
+     * @param labels       optional replacement candidate labels; must be non-empty when supplied
+     * @param isMultiLabel optional replacement for the stored multi_label flag
+     */
+    public ZeroShotClassificationConfigUpdate(
+        @Nullable List<String> labels,
+        @Nullable Boolean isMultiLabel
+    ) {
+        this.labels = labels;
+        if (labels != null && labels.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
+        }
+        this.isMultiLabel = isMultiLabel;
+    }
+
+    // Wire deserialization. Field order must be kept in sync with writeTo.
+    public ZeroShotClassificationConfigUpdate(StreamInput in) throws IOException {
+        labels = in.readOptionalStringList();
+        isMultiLabel = in.readOptionalBoolean();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalStringCollection(labels);
+        out.writeOptionalBoolean(isMultiLabel);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        // Both fields are optional and omitted when unset.
+        if (labels != null) {
+            builder.field(LABELS.getPreferredName(), labels);
+        }
+        if (isMultiLabel != null) {
+            builder.field(MULTI_LABEL.getPreferredName(), isMultiLabel);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    /**
+     * Merges this update into the stored config, returning the original instance unchanged
+     * when the update is a no-op. Labels must be available from either the stored config or
+     * this update; otherwise the request is rejected.
+     */
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof ZeroShotClassificationConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig)originalConfig;
+        if ((labels == null || labels.isEmpty()) && (zeroShotConfig.getLabels() == null || zeroShotConfig.getLabels().isEmpty())) {
+            throw ExceptionsHelper.badRequestException(
+                "stored configuration has no [{}] defined, supplied inference_config update must supply [{}]",
+                LABELS.getPreferredName(),
+                LABELS.getPreferredName()
+            );
+        }
+        if (isNoop(zeroShotConfig)) {
+            return originalConfig;
+        }
+        return new ZeroShotClassificationConfig(
+            zeroShotConfig.getClassificationLabels(),
+            zeroShotConfig.getVocabularyConfig(),
+            zeroShotConfig.getTokenization(),
+            zeroShotConfig.getHypothesisTemplate(),
+            Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
+            Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels())
+        );
+    }
+
+    boolean isNoop(ZeroShotClassificationConfig originalConfig) {
+        // Compare against the stored candidate labels — the field this update actually
+        // replaces in apply(). Previously this incorrectly compared against the fixed
+        // MNLI classification labels, so a genuine no-op was never detected.
+        return (labels == null || labels.equals(originalConfig.getLabels()))
+            && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()));
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof ZeroShotClassificationConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        // The results field is not customizable for this task.
+        return null;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder().setLabels(labels).setMultiLabel(isMultiLabel);
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        ZeroShotClassificationConfigUpdate that = (ZeroShotClassificationConfigUpdate) o;
+        return Objects.equals(isMultiLabel, that.isMultiLabel) && Objects.equals(labels, that.labels);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(labels, isMultiLabel);
+    }
+
+    public List<String> getLabels() {
+        return labels;
+    }
+
+    /** Fluent builder; both settings are optional. */
+    public static class Builder implements InferenceConfigUpdate.Builder<
+        ZeroShotClassificationConfigUpdate.Builder,
+        ZeroShotClassificationConfigUpdate
+        > {
+        private List<String> labels;
+        private Boolean isMultiLabel;
+
+        @Override
+        public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
+            // Unsupported for this task; see getResultsField().
+            throw new IllegalArgumentException();
+        }
+
+        public Builder setLabels(List<String> labels) {
+            this.labels = labels;
+            return this;
+        }
+
+        public Builder setMultiLabel(Boolean multiLabel) {
+            isMultiLabel = multiLabel;
+            return this;
+        }
+
+        public ZeroShotClassificationConfigUpdate build() {
+            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+        }
+    }
+}

+ 24 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java

@@ -7,19 +7,24 @@
 
 
 package org.elasticsearch.xpack.core.ml.action;
 package org.elasticsearch.xpack.core.ml.action;
 
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.core.Tuple;
-import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
 
 
-import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 
 
-public class InferTrainedModelDeploymentRequestsTests extends AbstractSerializingTestCase<InferTrainedModelDeploymentAction.Request> {
-    @Override
-    protected InferTrainedModelDeploymentAction.Request doParseInstance(XContentParser parser) throws IOException {
-        return InferTrainedModelDeploymentAction.Request.parseRequest(null, parser);
+public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerializingTestCase<InferTrainedModelDeploymentAction.Request> {
+
+
+    private static InferenceConfigUpdate randomInferenceConfigUpdate() {
+        return randomFrom(ZeroShotClassificationConfigUpdateTests.createRandom(), EmptyConfigUpdateTests.testInstance());
     }
     }
 
 
     @Override
     @Override
@@ -32,14 +37,24 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractSerializin
         List<Map<String, Object>> docs = randomList(5, () -> randomMap(1, 3,
         List<Map<String, Object>> docs = randomList(5, () -> randomMap(1, 3,
             () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7))));
             () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7))));
 
 
-        InferTrainedModelDeploymentAction.Request request =
-            new InferTrainedModelDeploymentAction.Request(randomAlphaOfLength(4), docs);
+        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
+            randomAlphaOfLength(4),
+            randomBoolean() ? null : randomInferenceConfigUpdate(),
+            docs
+        );
         if (randomBoolean()) {
         if (randomBoolean()) {
             request.setTimeout(randomTimeValue());
             request.setTimeout(randomTimeValue());
         }
         }
         return request;
         return request;
     }
     }
 
 
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
     public void testTimeoutNotNull() {
     public void testTimeoutNotNull() {
         assertNotNull(createTestInstance().getTimeout());
         assertNotNull(createTestInstance().getTimeout());
     }
     }

+ 61 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.Predicate;
+
+/** Serialization / parsing round-trip tests for {@link ZeroShotClassificationConfig}. */
+public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCase<ZeroShotClassificationConfig> {
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        // The lenient parser must tolerate unknown fields for forward compatibility.
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // Exclude every non-root path: random unknown fields are only injected at the
+        // top level of the object, not inside nested (strictly parsed) sub-objects.
+        return field -> field.isEmpty() == false;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig doParseInstance(XContentParser parser) throws IOException {
+        return ZeroShotClassificationConfig.fromXContentLenient(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ZeroShotClassificationConfig> instanceReader() {
+        return ZeroShotClassificationConfig::new;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ZeroShotClassificationConfig mutateInstanceForVersion(ZeroShotClassificationConfig instance, Version version) {
+        // Wire format is identical across all supported versions, so no mutation is needed.
+        return instance;
+    }
+
+    public static ZeroShotClassificationConfig createRandom() {
+        return new ZeroShotClassificationConfig(
+            // Only the required MNLI labels are valid; vary their order to exercise ordering.
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            // labels is optional but must be non-empty when present.
+            randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10))
+        );
+    }
+}

+ 134 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -0,0 +1,134 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+
+/** Tests for parsing, serialization and apply() semantics of {@link ZeroShotClassificationConfigUpdate}. */
+public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItemTestCase<ZeroShotClassificationConfigUpdate> {
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        // Updates are strictly parsed; unknown fields must be rejected.
+        return false;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return ZeroShotClassificationConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ZeroShotClassificationConfigUpdate> instanceReader() {
+        return ZeroShotClassificationConfigUpdate::new;
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
+        // Wire format is identical across all supported versions, so no mutation is needed.
+        return instance;
+    }
+
+    // fromMap should accept exactly the labels and multi_label keys.
+    public void testFromMap() {
+        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(List.of("foo", "bar"), false);
+        Map<String, Object> config = new HashMap<>(){{
+            put(ZeroShotClassificationConfig.LABELS.getPreferredName(), List.of("foo", "bar"));
+            put(ZeroShotClassificationConfig.MULTI_LABEL.getPreferredName(), false);
+        }};
+        assertThat(ZeroShotClassificationConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    // Any unconsumed key must produce a bad-request error naming the offending field.
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> ZeroShotClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+    public void testApply() {
+        ZeroShotClassificationConfig originalConfig = new ZeroShotClassificationConfig(
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            randomList(1, 5, () -> randomAlphaOfLength(10))
+        );
+
+        // An empty update applied to a config with labels is a no-op.
+        assertThat(originalConfig, equalTo(new ZeroShotClassificationConfigUpdate.Builder().build().apply(originalConfig)));
+
+        // Updating labels replaces only the candidate labels.
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getHypothesisTemplate(),
+                originalConfig.isMultiLabel(),
+                List.of("foo", "bar")
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder()
+                    .setLabels(List.of("foo", "bar")).build()
+                    .apply(originalConfig)
+            )
+        );
+        // Updating multi_label replaces only that flag.
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getHypothesisTemplate(),
+                true,
+                originalConfig.getLabels()
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder()
+                    .setMultiLabel(true).build()
+                    .apply(originalConfig)
+            )
+        );
+    }
+
+    // When neither the stored config nor the update supplies labels, apply() must fail.
+    public void testApplyWithEmptyLabelsInConfigAndUpdate() {
+        ZeroShotClassificationConfig originalConfig = new ZeroShotClassificationConfig(
+            randomFrom(List.of("entailment", "neutral", "contradiction"), List.of("contradiction", "neutral", "entailment")),
+            randomBoolean() ? null : VocabularyConfigTests.createRandom(),
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomAlphaOfLength(10),
+            randomBoolean(),
+            null
+        );
+
+        Exception ex = expectThrows(Exception.class, () -> new ZeroShotClassificationConfigUpdate.Builder().build().apply(originalConfig));
+        assertThat(
+            ex.getMessage(),
+            containsString("stored configuration has no [labels] defined, supplied inference_config update must supply [labels]")
+        );
+    }
+
+    public static ZeroShotClassificationConfigUpdate createRandom() {
+        return new ZeroShotClassificationConfigUpdate(
+            randomBoolean() ? null : randomList(1,5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomBoolean()
+        );
+    }
+}

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -80,7 +80,10 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
     @Override
     @Override
     protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task,
     protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task,
                                  ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
                                  ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
-        task.infer(request.getDocs().get(0), request.getTimeout(),
+        task.infer(
+            request.getDocs().get(0),
+            request.getUpdate(),
+            request.getTimeout(),
             ActionListener.wrap(
             ActionListener.wrap(
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
                 listener::onFailure)
                 listener::onFailure)

+ 9 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -138,7 +139,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
                 // Always fail immediately and return an error
                 // Always fail immediately and return an error
                 ex -> true);
                 ex -> true);
         request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(
         request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(
-            chainedTask -> inferSingleDocAgainstAllocatedModel(request.getModelId(), stringObjectMap, chainedTask)));
+            chainedTask -> inferSingleDocAgainstAllocatedModel(request.getModelId(), request.getUpdate(), stringObjectMap, chainedTask)));
 
 
         typedChainTaskExecutor.execute(ActionListener.wrap(
         typedChainTaskExecutor.execute(ActionListener.wrap(
             inferenceResults -> listener.onResponse(responseBuilder.setInferenceResults(inferenceResults)
             inferenceResults -> listener.onResponse(responseBuilder.setInferenceResults(inferenceResults)
@@ -148,11 +149,16 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         ));
         ));
     }
     }
 
 
-    private void inferSingleDocAgainstAllocatedModel(String modelId, Map<String, Object> doc, ActionListener<InferenceResults> listener) {
+    private void inferSingleDocAgainstAllocatedModel(
+        String modelId,
+        InferenceConfigUpdate inferenceConfigUpdate,
+        Map<String, Object> doc,
+        ActionListener<InferenceResults> listener
+    ) {
         executeAsyncWithOrigin(client,
         executeAsyncWithOrigin(client,
             ML_ORIGIN,
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
             InferTrainedModelDeploymentAction.INSTANCE,
-            new InferTrainedModelDeploymentAction.Request(modelId, Collections.singletonList(doc)),
+            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc)),
             ActionListener.wrap(
             ActionListener.wrap(
                 r -> listener.onResponse(r.getResults()),
                 r -> listener.onResponse(r.getResults()),
                 e -> listener.onResponse(new WarningInferenceResults(e.getMessage()))
                 e -> listener.onResponse(new WarningInferenceResults(e.getMessage()))

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
@@ -227,9 +228,12 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
         );
         );
     }
     }
 
 
-    public void infer(TrainedModelDeploymentTask task, Map<String, Object> doc, TimeValue timeout,
+    public void infer(TrainedModelDeploymentTask task,
+                      InferenceConfig config,
+                      Map<String, Object> doc,
+                      TimeValue timeout,
                       ActionListener<InferenceResults> listener) {
                       ActionListener<InferenceResults> listener) {
-        deploymentManager.infer(task, doc, timeout, listener);
+        deploymentManager.infer(task, config, doc, timeout, listener);
     }
     }
 
 
     public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
     public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {

+ 16 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
@@ -131,6 +132,7 @@ public class DeploymentManager {
 
 
                 assert modelConfig.getInferenceConfig() instanceof NlpConfig;
                 assert modelConfig.getInferenceConfig() instanceof NlpConfig;
                 NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
                 NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
+                task.init(nlpConfig);
 
 
                 SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                 SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                 executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
                 executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
@@ -203,7 +205,9 @@ public class DeploymentManager {
     }
     }
 
 
     public void infer(TrainedModelDeploymentTask task,
     public void infer(TrainedModelDeploymentTask task,
-                      Map<String, Object> doc, TimeValue timeout,
+                      InferenceConfig config,
+                      Map<String, Object> doc,
+                      TimeValue timeout,
                       ActionListener<InferenceResults> listener) {
                       ActionListener<InferenceResults> listener) {
         if (task.isStopped()) {
         if (task.isStopped()) {
             listener.onFailure(
             listener.onFailure(
@@ -240,12 +244,20 @@ public class DeploymentManager {
                     List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
                     List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
                     NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                     NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                     processor.validateInputs(text);
                     processor.validateInputs(text);
-                    NlpTask.Request request = processor.getRequestBuilder().buildRequest(text, requestId);
+                    assert config instanceof NlpConfig;
+                    NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestId);
                     logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
                     logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
                     PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
                     PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
                     processContext.process.get().writeInferenceRequest(request.processInput);
                     processContext.process.get().writeInferenceRequest(request.processInput);
-                    waitForResult(processContext, pendingResult, request.tokenization, requestId, timeout, processor.getResultProcessor(),
-                        listener);
+                    waitForResult(
+                        processContext,
+                        pendingResult,
+                        request.tokenization,
+                        requestId,
+                        timeout,
+                        processor.getResultProcessor((NlpConfig) config),
+                        listener
+                    );
                 } catch (IOException e) {
                 } catch (IOException e) {
                     logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
                     logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
                     onFailure(ExceptionsHelper.serverError("error writing to process", e));
                     onFailure(ExceptionsHelper.serverError("error writing to process", e));

+ 26 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -18,6 +18,8 @@ import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
 
 
@@ -32,6 +34,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
     private final TrainedModelAllocationNodeService trainedModelAllocationNodeService;
     private final TrainedModelAllocationNodeService trainedModelAllocationNodeService;
     private volatile boolean stopped;
     private volatile boolean stopped;
     private final SetOnce<String> stoppedReason = new SetOnce<>();
     private final SetOnce<String> stoppedReason = new SetOnce<>();
+    private final SetOnce<InferenceConfig> inferenceConfig = new SetOnce<>();
 
 
     public TrainedModelDeploymentTask(
     public TrainedModelDeploymentTask(
         long id,
         long id,
@@ -50,6 +53,10 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         );
         );
     }
     }
 
 
+    void init(InferenceConfig inferenceConfig) {
+        this.inferenceConfig.set(inferenceConfig);
+    }
+
     public String getModelId() {
     public String getModelId() {
         return params.getModelId();
         return params.getModelId();
     }
     }
@@ -85,8 +92,25 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         stop(reason);
         stop(reason);
     }
     }
 
 
-    public void infer(Map<String, Object> doc, TimeValue timeout, ActionListener<InferenceResults> listener) {
-        trainedModelAllocationNodeService.infer(this, doc, timeout, listener);
+    public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
+        if (inferenceConfig.get() == null) {
+            listener.onFailure(
+                ExceptionsHelper.badRequestException("[{}] inference not possible against uninitialized model", params.getModelId())
+            );
+            return;
+        }
+        if (update.isSupported(inferenceConfig.get()) == false) {
+            listener.onFailure(
+                ExceptionsHelper.badRequestException(
+                    "[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
+                    params.getModelId(),
+                    inferenceConfig.get().getName(),
+                    update.getName()
+                )
+            );
+            return;
+        }
+        trainedModelAllocationNodeService.infer(this, update.apply(inferenceConfig.get()), doc, timeout, listener);
     }
     }
 
 
     public Optional<ModelStats> modelStats() {
     public Optional<ModelStats> modelStats() {

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -37,6 +37,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
@@ -358,7 +360,11 @@ public class InferenceProcessor extends AbstractProcessor {
             } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
             } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                 checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                 return RegressionConfigUpdate.fromMap(valueMap);
                 return RegressionConfigUpdate.fromMap(valueMap);
-            } else {
+            } else if (configMap.containsKey(ZeroShotClassificationConfig.NAME)) {
+                checkSupportedVersion(new ZeroShotClassificationConfig(List.of("unused"), null, null, null, null, null));
+                return ZeroShotClassificationConfigUpdate.fromMap(valueMap);
+            }
+            else {
                 throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
                 throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
                     configMap.keySet(),
                     configMap.keySet(),
                     Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
                     Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));

+ 13 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.util.List;
 import java.util.List;
+import java.util.stream.Collectors;
 
 
 public class BertRequestBuilder implements NlpTask.RequestBuilder {
 public class BertRequestBuilder implements NlpTask.RequestBuilder {
 
 
@@ -37,7 +38,18 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
                 " token in its vocabulary");
                 " token in its vocabulary");
         }
         }
 
 
-        TokenizationResult tokenization = tokenizer.tokenize(inputs);
+        TokenizationResult tokenization = tokenizer.buildTokenizationResult(
+            inputs.stream().map(tokenizer::tokenize).collect(Collectors.toList())
+        );
+        return buildRequest(tokenization, requestId);
+    }
+
+    @Override
+    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
+        if (tokenizer.getPadToken().isEmpty()) {
+            throw new IllegalStateException("The input tokenizer does not have a " + BertTokenizer.PAD_TOKEN +
+                " token in its vocabulary");
+        }
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
         return new NlpTask.Request(tokenization, jsonRequest(tokenization, tokenizer.getPadToken().getAsInt(), requestId));
     }
     }
 
 

+ 6 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -10,12 +10,14 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.List;
 import java.util.List;
 
 
@@ -49,23 +51,23 @@ public class FillMaskProcessor implements NlpTask.Processor {
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
         return requestBuilder;
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return this::processResult;
         return this::processResult;
     }
     }
 
 
     InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
     InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
 
 
         if (tokenization.getTokenizations().isEmpty() ||
         if (tokenization.getTokenizations().isEmpty() ||
-            tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
+            tokenization.getTokenizations().get(0).getTokens().length == 0) {
             return new FillMaskResults(Collections.emptyList());
             return new FillMaskResults(Collections.emptyList());
         }
         }
 
 
-        int maskTokenIndex = tokenization.getTokenizations().get(0).getTokens().indexOf(BertTokenizer.MASK_TOKEN);
+        int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN);
         // TODO - process all results in the batch
         // TODO - process all results in the batch
         double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
         double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
 
 

+ 8 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
@@ -124,12 +125,12 @@ public class NerProcessor implements NlpTask.Processor {
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
         return requestBuilder;
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return new NerResultProcessor(iobMap);
         return new NerResultProcessor(iobMap);
     }
     }
 
 
@@ -143,7 +144,7 @@ public class NerProcessor implements NlpTask.Processor {
         @Override
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
             if (tokenization.getTokenizations().isEmpty() ||
             if (tokenization.getTokenizations().isEmpty() ||
-                tokenization.getTokenizations().get(0).getTokens().isEmpty()) {
+                tokenization.getTokenizations().get(0).getTokens().length == 0) {
                 return new NerResults(Collections.emptyList());
                 return new NerResults(Collections.emptyList());
             }
             }
             // TODO - process all results in the batch
             // TODO - process all results in the batch
@@ -171,7 +172,7 @@ public class NerProcessor implements NlpTask.Processor {
                                            IobTag[] iobMap) {
                                            IobTag[] iobMap) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
             int startTokenIndex = 0;
-            while (startTokenIndex < tokenization.getTokens().size()) {
+            while (startTokenIndex < tokenization.getTokens().length) {
                 int inputMapping = tokenization.getTokenMap()[startTokenIndex];
                 int inputMapping = tokenization.getTokenMap()[startTokenIndex];
                 if (inputMapping < 0) {
                 if (inputMapping < 0) {
                     // This token does not map to a token in the input (special tokens)
                     // This token does not map to a token in the input (special tokens)
@@ -179,14 +180,14 @@ public class NerProcessor implements NlpTask.Processor {
                     continue;
                     continue;
                 }
                 }
                 int endTokenIndex = startTokenIndex;
                 int endTokenIndex = startTokenIndex;
-                StringBuilder word = new StringBuilder(tokenization.getTokens().get(startTokenIndex));
-                while (endTokenIndex < tokenization.getTokens().size() - 1
+                StringBuilder word = new StringBuilder(tokenization.getTokens()[startTokenIndex]);
+                while (endTokenIndex < tokenization.getTokens().length - 1
                     && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) {
                     && tokenization.getTokenMap()[endTokenIndex + 1] == inputMapping) {
                     endTokenIndex++;
                     endTokenIndex++;
                     // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens.
                     // TODO Here we try to get rid of the continuation hashes at the beginning of sub-tokens.
                     // It is probably more correct to implement detokenization on the tokenizer
                     // It is probably more correct to implement detokenization on the tokenizer
                     // that does reverse lookup based on token IDs.
                     // that does reverse lookup based on token IDs.
-                    String endTokenWord = tokenization.getTokens().get(endTokenIndex).substring(2);
+                    String endTokenWord = tokenization.getTokens()[endTokenIndex].substring(2);
                     word.append(endTokenWord);
                     word.append(endTokenWord);
                 }
                 }
                 double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);
                 double[] avgScores = Arrays.copyOf(scores[startTokenIndex], iobMap.length);

+ 4 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -56,6 +56,8 @@ public class NlpTask {
 
 
         Request buildRequest(List<String> inputs, String requestId) throws IOException;
         Request buildRequest(List<String> inputs, String requestId) throws IOException;
 
 
+        Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException;
+
         static void writePaddedTokens(String fieldName,
         static void writePaddedTokens(String fieldName,
                                       TokenizationResult tokenization,
                                       TokenizationResult tokenization,
                                       int padToken,
                                       int padToken,
@@ -97,10 +99,6 @@ public class NlpTask {
         InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult);
         InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult);
     }
     }
 
 
-    public interface ResultProcessorFactory {
-        ResultProcessor build(TokenizationResult tokenizationResult);
-    }
-
     public interface Processor {
     public interface Processor {
         /**
         /**
          * Validate the task input string.
          * Validate the task input string.
@@ -110,8 +108,8 @@ public class NlpTask {
          */
          */
         void validateInputs(List<String> inputs);
         void validateInputs(List<String> inputs);
 
 
-        RequestBuilder getRequestBuilder();
-        ResultProcessor getResultProcessor();
+        RequestBuilder getRequestBuilder(NlpConfig config);
+        ResultProcessor getResultProcessor(NlpConfig config);
     }
     }
 
 
     public static String extractInput(TrainedModelInput input, Map<String, Object> doc) {
     public static String extractInput(TrainedModelInput input, Map<String, Object> doc) {

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -34,12 +35,12 @@ public class PassThroughProcessor implements NlpTask.Processor {
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
         return requestBuilder;
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return PassThroughProcessor::processResult;
         return PassThroughProcessor::processResult;
     }
     }
 
 

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 
 
 import java.util.Locale;
 import java.util.Locale;
@@ -48,6 +49,12 @@ public enum TaskType {
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
         public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
             return new TextEmbeddingProcessor(tokenizer, (TextEmbeddingConfig) config);
             return new TextEmbeddingProcessor(tokenizer, (TextEmbeddingConfig) config);
         }
         }
+    },
+    ZERO_SHOT_CLASSIFICATION {
+        @Override
+        public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
+            return new ZeroShotClassificationProcessor(tokenizer, (ZeroShotClassificationConfig) config);
+        }
     };
     };
 
 
     public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {
     public NlpTask.Processor createProcessor(NlpTokenizer tokenizer, NlpConfig config) {

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -73,12 +74,12 @@ public class TextClassificationProcessor implements NlpTask.Processor {
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
         return requestBuilder;
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return this::processResult;
         return this::processResult;
     }
     }
 
 

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -33,12 +34,12 @@ public class TextEmbeddingProcessor implements NlpTask.Processor {
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.RequestBuilder getRequestBuilder() {
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
         return requestBuilder;
         return requestBuilder;
     }
     }
 
 
     @Override
     @Override
-    public NlpTask.ResultProcessor getResultProcessor() {
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
         return TextEmbeddingProcessor::processResult;
         return TextEmbeddingProcessor::processResult;
     }
     }
 
 

+ 194 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -0,0 +1,194 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class ZeroShotClassificationProcessor implements NlpTask.Processor {
+
+    private final NlpTokenizer tokenizer;
+    private final int entailmentPos;
+    private final int contraPos;
+    private final String[] labels;
+    private final String hypothesisTemplate;
+    private final boolean isMultiLabel;
+
+    ZeroShotClassificationProcessor(NlpTokenizer tokenizer, ZeroShotClassificationConfig config) {
+        this.tokenizer = tokenizer;
+        List<String> lowerCased = config.getClassificationLabels()
+            .stream()
+            .map(s -> s.toLowerCase(Locale.ROOT))
+            .collect(Collectors.toList());
+        this.entailmentPos = lowerCased.indexOf("entailment");
+        this.contraPos = lowerCased.indexOf("contradiction");
+        if (entailmentPos == -1 || contraPos == -1) {
+            throw ExceptionsHelper.badRequestException(
+                "zero_shot_classification requires [entailment] and [contradiction] in classification_labels"
+            );
+        }
+        this.labels = Optional.ofNullable(config.getLabels()).orElse(List.of()).toArray(String[]::new);
+        this.hypothesisTemplate = config.getHypothesisTemplate();
+        this.isMultiLabel = config.isMultiLabel();
+    }
+
+    @Override
+    public void validateInputs(List<String> inputs) {
+        // nothing to validate
+    }
+
+    @Override
+    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
+        final String[] labels;
+        if (nlpConfig instanceof ZeroShotClassificationConfig) {
+            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
+            labels = zeroShotConfig.getLabels().toArray(new String[0]);
+        } else {
+            labels = this.labels;
+        }
+        if (this.labels == null || this.labels.length == 0) {
+            throw ExceptionsHelper.badRequestException("zero_shot_classification requires non-empty [labels]");
+        }
+        return new RequestBuilder(tokenizer, labels, hypothesisTemplate);
+    }
+
+    @Override
+    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
+        final String[] labels;
+        final boolean isMultiLabel;
+        if (nlpConfig instanceof ZeroShotClassificationConfig) {
+            ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
+            labels = zeroShotConfig.getLabels().toArray(new String[0]);
+            isMultiLabel = zeroShotConfig.isMultiLabel();
+        } else {
+            labels = this.labels;
+            isMultiLabel = this.isMultiLabel;
+        }
+        return new ResultProcessor(entailmentPos, contraPos, labels, isMultiLabel);
+    }
+
+    static class RequestBuilder implements NlpTask.RequestBuilder {
+
+        private final NlpTokenizer tokenizer;
+        private final String[] labels;
+        private final String hypothesisTemplate;
+
+        RequestBuilder(NlpTokenizer tokenizer, String[] labels, String hypothesisTemplate) {
+            this.tokenizer = tokenizer;
+            this.labels = labels;
+            this.hypothesisTemplate = hypothesisTemplate;
+        }
+
+        @Override
+        public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
+            if (inputs.size() > 1) {
+                throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
+            }
+            List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
+            for (String label : labels) {
+                tokenizations.add(tokenizer.tokenize(inputs.get(0), LoggerMessageFormat.format(null, hypothesisTemplate, label)));
+            }
+            TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations);
+            return buildRequest(result, requestId);
+        }
+
+        @Override
+        public NlpTask.Request buildRequest(TokenizationResult tokenizationResult, String requestId) throws IOException {
+            return tokenizer.requestBuilder().buildRequest(tokenizationResult, requestId);
+        }
+    }
+
+    static class ResultProcessor implements NlpTask.ResultProcessor {
+        private final int entailmentPos;
+        private final int contraPos;
+        private final String[] labels;
+        private final boolean isMultiLabel;
+
+        ResultProcessor(int entailmentPos, int contraPos, String[] labels, boolean isMultiLabel) {
+            this.entailmentPos = entailmentPos;
+            this.contraPos = contraPos;
+            this.labels = labels;
+            this.isMultiLabel = isMultiLabel;
+        }
+
+        @Override
+        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+            if (pyTorchResult.getInferenceResult().length < 1) {
+                return new WarningInferenceResults("Zero shot classification result has no data");
+            }
+            // TODO only the first entry in the batch result is verified and
+            // checked. Implement for all in batch
+            if (pyTorchResult.getInferenceResult()[0].length != labels.length) {
+                return new WarningInferenceResults(
+                    "Expected exactly [{}] values in zero shot classification result; got [{}]",
+                    labels.length,
+                    pyTorchResult.getInferenceResult().length
+                );
+            }
+            final double[] normalizedScores;
+            if (isMultiLabel) {
+                normalizedScores = new double[pyTorchResult.getInferenceResult()[0].length];
+                int v = 0;
+                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
+                    if (vals.length != 3) {
+                        return new WarningInferenceResults(
+                            "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                            3,
+                            vals.length
+                        );
+                    }
+                    // assume entailment is `0`, softmax between entailment and contradiction
+                    normalizedScores[v++] = NlpHelpers.convertToProbabilitiesBySoftMax(
+                        new double[]{vals[entailmentPos], vals[contraPos]}
+                    )[0];
+                }
+            } else {
+                double[] entailmentScores = new double[pyTorchResult.getInferenceResult()[0].length];
+                int v = 0;
+                for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
+                    if (vals.length != 3) {
+                        return new WarningInferenceResults(
+                            "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                            3,
+                            vals.length
+                        );
+                    }
+                    entailmentScores[v++] = vals[entailmentPos];
+                }
+                normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(entailmentScores);
+            }
+
+            return new TextClassificationResults(
+                IntStream.range(0, normalizedScores.length)
+                    .mapToObj(i -> new TopClassEntry(labels[i], normalizedScores[i]))
+                    // Put the highest scoring class first
+                    .sorted(Comparator.comparing(TopClassEntry::getProbability).reversed())
+                    .limit(labels.length)
+                    .collect(Collectors.toList())
+            );
+        }
+    }
+}

+ 110 - 52
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
 
 
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
 import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
@@ -76,74 +77,63 @@ public class BertTokenizer implements NlpTokenizer {
         this.requestBuilder = requestBuilderFactory.apply(this);
         this.requestBuilder = requestBuilderFactory.apply(this);
     }
     }
 
 
+    @Override
+    public OptionalInt getPadToken() {
+        Integer pad = vocab.get(PAD_TOKEN);
+        if (pad != null) {
+            return OptionalInt.of(pad);
+        } else {
+            return OptionalInt.empty();
+        }
+    }
+
+    @Override
+    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations) {
+        TokenizationResult tokenizationResult = new TokenizationResult(originalVocab);
+        for (TokenizationResult.Tokenization tokenization : tokenizations) {
+            tokenizationResult.addTokenization(tokenization);
+        }
+        return tokenizationResult;
+    }
+
     /**
     /**
-     * Tokenize the list of inputs according to the basic tokenization
+     * Tokenize the input according to the basic tokenization
      * options then perform Word Piece tokenization with the given vocabulary.
      * options then perform Word Piece tokenization with the given vocabulary.
      *
      *
      * The result is the Word Piece tokens, a map of the Word Piece
      * The result is the Word Piece tokens, a map of the Word Piece
      * token position to the position of the token in the source for
      * token position to the position of the token in the source for
      * each input string grouped into a {@link Tokenization}.
      * each input string grouped into a {@link Tokenization}.
      *
      *
-     * @param text Text to tokenize
+     * @param seq Text to tokenize
      * @return A {@link Tokenization}
      * @return A {@link Tokenization}
      */
      */
     @Override
     @Override
-    public TokenizationResult tokenize(List<String> text) {
-        TokenizationResult tokenization = new TokenizationResult(originalVocab);
-
-        for (String input: text) {
-            addTokenization(tokenization, input);
-        }
-        return tokenization;
-    }
-
-
-    private void addTokenization(TokenizationResult tokenization, String text) {
-        BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
-
-        List<String> delineatedTokens = basicTokenizer.tokenize(text);
-        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = new ArrayList<>();
-        List<Integer> tokenPositionMap = new ArrayList<>();
-        if (withSpecialTokens) {
-            // insert the first token to simplify the loop counter logic later
-            tokenPositionMap.add(SPECIAL_TOKEN_POSITION);
-        }
-
-        for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) {
-            String token = delineatedTokens.get(sourceIndex);
-            if (neverSplit.contains(token)) {
-                wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))));
-                tokenPositionMap.add(sourceIndex);
-            } else {
-                List<WordPieceTokenizer.TokenAndId> tokens = wordPieceTokenizer.tokenize(token);
-                for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) {
-                    tokenPositionMap.add(sourceIndex);
-                }
-                wordPieceTokens.addAll(tokens);
-            }
-        }
-
+    public TokenizationResult.Tokenization tokenize(String seq) {
+        var innerResult = innerTokenize(seq);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
+        List<Integer> tokenPositionMap = innerResult.v2();
         int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size();
         int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size();
-        List<String> tokens = new ArrayList<>(numTokens);
-        int [] tokenIds = new int[numTokens];
-        int [] tokenMap = new int[numTokens];
+        String[] tokens = new String[numTokens];
+        int[] tokenIds = new int[numTokens];
+        int[] tokenMap = new int[numTokens];
 
 
         if (withSpecialTokens) {
         if (withSpecialTokens) {
-            tokens.add(CLASS_TOKEN);
+            tokens[0] = CLASS_TOKEN;
             tokenIds[0] = vocab.get(CLASS_TOKEN);
             tokenIds[0] = vocab.get(CLASS_TOKEN);
             tokenMap[0] = SPECIAL_TOKEN_POSITION;
             tokenMap[0] = SPECIAL_TOKEN_POSITION;
         }
         }
 
 
         int i = withSpecialTokens ? 1 : 0;
         int i = withSpecialTokens ? 1 : 0;
+        final int decrementHandler = withSpecialTokens ? 1 : 0;
         for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) {
         for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) {
-            tokens.add(tokenAndId.getToken());
+            tokens[i] = tokenAndId.getToken();
             tokenIds[i] = tokenAndId.getId();
             tokenIds[i] = tokenAndId.getId();
-            tokenMap[i] = tokenPositionMap.get(i);
+            tokenMap[i] = tokenPositionMap.get(i-decrementHandler);
             i++;
             i++;
         }
         }
 
 
         if (withSpecialTokens) {
         if (withSpecialTokens) {
-            tokens.add(SEPARATOR_TOKEN);
+            tokens[i] = SEPARATOR_TOKEN;
             tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
             tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
         }
@@ -155,18 +145,86 @@ public class BertTokenizer implements NlpTokenizer {
                 maxSequenceLength
                 maxSequenceLength
             );
             );
         }
         }
-
-        tokenization.addTokenization(text, tokens, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(seq, tokens, tokenIds, tokenMap);
     }
     }
 
 
     @Override
     @Override
-    public OptionalInt getPadToken() {
-        Integer pad = vocab.get(PAD_TOKEN);
-        if (pad != null) {
-            return OptionalInt.of(pad);
-        } else {
-            return OptionalInt.empty();
+    public TokenizationResult.Tokenization tokenize(String seq1, String seq2) {
+        var innerResult = innerTokenize(seq1);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq1s = innerResult.v1();
+        List<Integer> tokenPositionMapSeq1 = innerResult.v2();
+        innerResult = innerTokenize(seq2);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq2s = innerResult.v1();
+        List<Integer> tokenPositionMapSeq2 = innerResult.v2();
+        if (withSpecialTokens == false)  {
+            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
+        }
+        // [CLS] seq1 [SEP] seq2 [SEP]
+        int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3;
+        String[] tokens = new String[numTokens];
+        int[] tokenIds = new int[numTokens];
+        int[] tokenMap = new int[numTokens];
+
+        tokens[0] = CLASS_TOKEN;
+        tokenIds[0] = vocab.get(CLASS_TOKEN);
+        tokenMap[0] = SPECIAL_TOKEN_POSITION;
+
+        int i = 1;
+        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq1s) {
+            tokens[i] = tokenAndId.getToken();
+            tokenIds[i] = tokenAndId.getId();
+            tokenMap[i] = tokenPositionMapSeq1.get(i - 1);
+            i++;
+        }
+        tokens[i] = SEPARATOR_TOKEN;
+        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
+        tokenMap[i] = SPECIAL_TOKEN_POSITION;
+        ++i;
+
+        int j = 0;
+        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq2s) {
+            tokens[i] = tokenAndId.getToken();
+            tokenIds[i] = tokenAndId.getId();
+            tokenMap[i] = tokenPositionMapSeq2.get(j);
+            i++;
+            j++;
+        }
+
+        tokens[i] = SEPARATOR_TOKEN;
+        tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
+        tokenMap[i] = SPECIAL_TOKEN_POSITION;
+
+        // TODO handle seq1 truncation
+        if (tokenIds.length > maxSequenceLength) {
+            throw ExceptionsHelper.badRequestException(
+                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
+                tokenIds.length,
+                maxSequenceLength
+            );
+        }
+        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, tokenIds, tokenMap);
+    }
+
+    private Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerTokenize(String seq) {
+        BasicTokenizer basicTokenizer = new BasicTokenizer(doLowerCase, doTokenizeCjKChars, doStripAccents, neverSplit);
+        List<String> delineatedTokens = basicTokenizer.tokenize(seq);
+        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = new ArrayList<>();
+        List<Integer> tokenPositionMap = new ArrayList<>();
+
+        for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); sourceIndex++) {
+            String token = delineatedTokens.get(sourceIndex);
+            if (neverSplit.contains(token)) {
+                wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, vocab.getOrDefault(token, vocab.get(UNKNOWN_TOKEN))));
+                tokenPositionMap.add(sourceIndex);
+            } else {
+                List<WordPieceTokenizer.TokenAndId> tokens = wordPieceTokenizer.tokenize(token);
+                for (int tokenCount = 0; tokenCount < tokens.size(); tokenCount++) {
+                    tokenPositionMap.add(sourceIndex);
+                }
+                wordPieceTokens.addAll(tokens);
+            }
         }
         }
+        return Tuple.tuple(wordPieceTokens, tokenPositionMap);
     }
     }
 
 
     @Override
     @Override

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

@@ -22,7 +22,11 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.V
 
 
 public interface NlpTokenizer {
 public interface NlpTokenizer {
 
 
-    TokenizationResult tokenize(List<String> text);
+    TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations);
+
+    TokenizationResult.Tokenization tokenize(String seq);
+
+    TokenizationResult.Tokenization tokenize(String seq1, String seq2);
 
 
     NlpTask.RequestBuilder requestBuilder();
     NlpTask.RequestBuilder requestBuilder();
 
 

+ 13 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -29,26 +29,31 @@ public class TokenizationResult {
         return tokenizations;
         return tokenizations;
     }
     }
 
 
-    public void addTokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
+    public void addTokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
         maxLength = Math.max(maxLength, tokenIds.length);
         maxLength = Math.max(maxLength, tokenIds.length);
         tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap));
         tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap));
     }
     }
 
 
+    public void addTokenization(Tokenization tokenization) {
+        maxLength = Math.max(maxLength, tokenization.tokenIds.length);
+        tokenizations.add(tokenization);
+    }
+
     public int getLongestSequenceLength() {
     public int getLongestSequenceLength() {
         return maxLength;
         return maxLength;
     }
     }
 
 
     public static class Tokenization {
     public static class Tokenization {
 
 
-        String input;
-        private final List<String> tokens;
+        private final String inputSeqs;
+        private final String[] tokens;
         private final int[] tokenIds;
         private final int[] tokenIds;
         private final int[] tokenMap;
         private final int[] tokenMap;
 
 
-        public Tokenization(String input, List<String> tokens, int[] tokenIds, int[] tokenMap) {
-            assert tokens.size() == tokenIds.length;
+        public Tokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
+            assert tokens.length == tokenIds.length;
             assert tokenIds.length == tokenMap.length;
             assert tokenIds.length == tokenMap.length;
-            this.input = input;
+            this.inputSeqs = input;
             this.tokens = tokens;
             this.tokens = tokens;
             this.tokenIds = tokenIds;
             this.tokenIds = tokenIds;
             this.tokenMap = tokenMap;
             this.tokenMap = tokenMap;
@@ -59,7 +64,7 @@ public class TokenizationResult {
          *
          *
          * @return A list of tokens
          * @return A list of tokens
          */
          */
-        public List<String> getTokens() {
+        public String[] getTokens() {
             return tokens;
             return tokens;
         }
         }
 
 
@@ -84,7 +89,7 @@ public class TokenizationResult {
         }
         }
 
 
         public String getInput() {
         public String getInput() {
-            return input;
+            return inputSeqs;
         }
         }
     }
     }
 }
 }

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -41,7 +41,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris";
         String input = "The capital of " + BertTokenizer.MASK_TOKEN + " is Paris";
 
 
         List<String> vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France");
         List<String> vocab = Arrays.asList("The", "capital", "of", BertTokenizer.MASK_TOKEN, "is", "Paris", "France");
-        List<String> tokens = Arrays.asList(input.split(" "));
+        String[] tokens = input.split(" ");
         int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenMap = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5};
         int[] tokenIds = new int[] {0, 1, 2, 3, 4, 5};
 
 
@@ -68,7 +68,7 @@ public class FillMaskProcessorTests extends ESTestCase {
 
 
     public void testProcessResults_GivenMissingTokens() {
     public void testProcessResults_GivenMissingTokens() {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
-        tokenization.addTokenization("", Collections.emptyList(), new int[] {}, new int[] {});
+        tokenization.addTokenization("", new String[]{}, new int[] {}, new int[] {});
 
 
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -222,6 +222,6 @@ public class NerProcessorTests extends ESTestCase {
             vocab,
             vocab,
             new BertTokenization(true, false, null)
             new BertTokenization(true, false, null)
         ).setDoLowerCase(true).setWithSpecialTokens(false).build();
         ).setDoLowerCase(true).setWithSpecialTokens(false).build();
-        return tokenizer.tokenize(List.of(input));
+        return tokenizer.buildTokenizationResult(List.of(tokenizer.tokenize(input)));
     }
     }
 }
 }

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -65,7 +65,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
 
-        NlpTask.Request request = processor.getRequestBuilder().buildRequest(List.of("Elasticsearch fun"), "request1");
+        NlpTask.Request request = processor.getRequestBuilder(config).buildRequest(List.of("Elasticsearch fun"), "request1");
 
 
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
         Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
 

+ 78 - 34
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizerTests.java

@@ -15,7 +15,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.List;
 import java.util.List;
 
 
-import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.hasSize;
 
 
 public class BertTokenizerTests extends ESTestCase {
 public class BertTokenizerTests extends ESTestCase {
@@ -26,9 +26,8 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null)
             new BertTokenization(null, false, null)
         ).build();
         ).build();
 
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", "fun"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", "fun"));
         assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
     }
     }
@@ -39,9 +38,8 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
             Tokenization.createDefault()
         ).build();
         ).build();
 
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("elasticsearch fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("[CLS]", "elastic", "##search", "fun", "[SEP]"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("elasticsearch fun");
+        assertThat(tokenization.getTokens(), arrayContaining("[CLS]", "elastic", "##search", "fun", "[SEP]"));
         assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {3, 0, 1, 2, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {-1, 0, 0, 1, -1}, tokenization.getTokenMap());
     }
     }
@@ -56,9 +54,8 @@ public class BertTokenizerTests extends ESTestCase {
          .setWithSpecialTokens(false)
          .setWithSpecialTokens(false)
          .build();
          .build();
 
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch " + specialToken + " fun"));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", specialToken, "fun"));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch " + specialToken + " fun");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", specialToken, "fun"));
         assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 3, 2}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1, 2}, tokenization.getTokenMap());
     }
     }
@@ -72,15 +69,13 @@ public class BertTokenizerTests extends ESTestCase {
              .setWithSpecialTokens(false)
              .setWithSpecialTokens(false)
              .build();
              .build();
 
 
-            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains(BertTokenizer.UNKNOWN_TOKEN, "fun"));
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining(BertTokenizer.UNKNOWN_TOKEN, "fun"));
             assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
             assertArrayEquals(new int[] {3, 2}, tokenization.getTokenIds());
             assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
             assertArrayEquals(new int[] {0, 1}, tokenization.getTokenMap());
 
 
-            tr = tokenizer.tokenize(List.of("elasticsearch fun"));
-            tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
+            tokenization = tokenizer.tokenize("elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
         }
 
 
         {
         {
@@ -89,9 +84,8 @@ public class BertTokenizerTests extends ESTestCase {
                 .setWithSpecialTokens(false)
                 .setWithSpecialTokens(false)
                 .build();
                 .build();
 
 
-            TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch fun"));
-            TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-            assertThat(tokenization.getTokens(), contains("elastic", "##search", "fun"));
+            TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch fun");
+            assertThat(tokenization.getTokens(), arrayContaining("elastic", "##search", "fun"));
         }
         }
     }
     }
 
 
@@ -101,15 +95,13 @@ public class BertTokenizerTests extends ESTestCase {
             Tokenization.createDefault()
             Tokenization.createDefault()
         ).setWithSpecialTokens(false).build();
         ).setWithSpecialTokens(false).build();
 
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch, fun."));
-        TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "."));
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch, fun.");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 4, 2, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3}, tokenization.getTokenMap());
 
 
-        tr = tokenizer.tokenize(List.of("Elasticsearch, fun [MASK]."));
-        tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search", ",", "fun", "[MASK]", "."));
+        tokenization = tokenizer.tokenize("Elasticsearch, fun [MASK].");
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search", ",", "fun", "[MASK]", "."));
         assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 4, 2, 5, 3}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
     }
     }
@@ -124,31 +116,83 @@ public class BertTokenizerTests extends ESTestCase {
             new BertTokenization(null, false, null)
             new BertTokenization(null, false, null)
         ).build();
         ).build();
 
 
-        TokenizationResult tr = tokenizer.tokenize(List.of("Elasticsearch",
-            "my little red car",
-            "Godzilla day",
-            "Godzilla Pancake red car day"
-            ));
+        TokenizationResult tr = tokenizer.buildTokenizationResult(
+            List.of(
+                tokenizer.tokenize("Elasticsearch"),
+                tokenizer.tokenize("my little red car"),
+                tokenizer.tokenize("Godzilla day"),
+                tokenizer.tokenize("Godzilla Pancake red car day")
+            )
+        );
         assertThat(tr.getTokenizations(), hasSize(4));
         assertThat(tr.getTokenizations(), hasSize(4));
 
 
         TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
         TokenizationResult.Tokenization tokenization = tr.getTokenizations().get(0);
-        assertThat(tokenization.getTokens(), contains("Elastic", "##search"));
+        assertThat(tokenization.getTokens(), arrayContaining("Elastic", "##search"));
         assertArrayEquals(new int[] {0, 1}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0}, tokenization.getTokenMap());
 
 
         tokenization = tr.getTokenizations().get(1);
         tokenization = tr.getTokenizations().get(1);
-        assertThat(tokenization.getTokens(), contains("my", "little", "red", "car"));
+        assertThat(tokenization.getTokens(), arrayContaining("my", "little", "red", "car"));
         assertArrayEquals(new int[] {5, 6, 7, 8}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {5, 6, 7, 8}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 1, 2, 3}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 1, 2, 3}, tokenization.getTokenMap());
 
 
         tokenization = tr.getTokenizations().get(2);
         tokenization = tr.getTokenizations().get(2);
-        assertThat(tokenization.getTokens(), contains("God", "##zilla", "day"));
+        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "day"));
         assertArrayEquals(new int[] {9, 10, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {9, 10, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1}, tokenization.getTokenMap());
 
 
         tokenization = tr.getTokenizations().get(3);
         tokenization = tr.getTokenizations().get(3);
-        assertThat(tokenization.getTokens(), contains("God", "##zilla", "Pancake", "red", "car", "day"));
+        assertThat(tokenization.getTokens(), arrayContaining("God", "##zilla", "Pancake", "red", "car", "day"));
         assertArrayEquals(new int[] {9, 10, 3, 7, 8, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {9, 10, 3, 7, 8, 4}, tokenization.getTokenIds());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
         assertArrayEquals(new int[] {0, 0, 1, 2, 3, 4}, tokenization.getTokenMap());
     }
     }
+
+    public void testMultiSeqTokenization() {
+        List<String> vocab = List.of(
+            "Elastic",
+            "##search",
+            "is",
+            "fun",
+            "my",
+            "little",
+            "red",
+            "car",
+            "God",
+            "##zilla",
+            BertTokenizer.CLASS_TOKEN,
+            BertTokenizer.SEPARATOR_TOKEN
+        );
+        BertTokenizer tokenizer = BertTokenizer.builder(vocab, Tokenization.createDefault())
+            .setDoLowerCase(false)
+            .setWithSpecialTokens(true)
+            .build();
+        TokenizationResult.Tokenization tokenization = tokenizer.tokenize("Elasticsearch is fun", "Godzilla my little red car");
+        assertThat(
+            tokenization.getTokens(),
+            arrayContaining(
+                BertTokenizer.CLASS_TOKEN,
+                "Elastic",
+                "##search",
+                "is",
+                "fun",
+                BertTokenizer.SEPARATOR_TOKEN,
+                "God",
+                "##zilla",
+                "my",
+                "little",
+                "red",
+                "car",
+                BertTokenizer.SEPARATOR_TOKEN
+            )
+        );
+        assertArrayEquals(new int[] { 10, 0, 1, 2, 3, 11, 8, 9, 4, 5, 6, 7, 11 }, tokenization.getTokenIds());
+    }
+
+    public void testMultiSeqRequiresSpecialTokens() {
+        BertTokenizer tokenizer = BertTokenizer.builder(List.of("foo"), Tokenization.createDefault())
+            .setDoLowerCase(false)
+            .setWithSpecialTokens(false)
+            .build();
+        expectThrows(Exception.class, () -> tokenizer.tokenize("foo", "foo"));
+    }
 }
 }