Browse Source

[ML] Add Inference time configuration overrides (#78441)

Allow inference settings to be updated at the point of inference
Updating the results_field is added for all task types, additionally 
fill_mask tasks can set the number of results (previously hard-coded 
to 5) and text_classification can set the classification labels 
(e.g. happy/sad)
David Kyle 4 years ago
parent
commit
38ecd0baa1
41 changed files with 1797 additions and 165 deletions
  1. 31 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  2. 1 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
  3. 84 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java
  4. 179 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdate.java
  5. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
  6. 20 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java
  7. 153 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdate.java
  8. 3 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java
  9. 0 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java
  10. 5 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java
  11. 21 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java
  12. 151 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdate.java
  13. 1 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java
  14. 97 24
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java
  15. 211 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdate.java
  16. 20 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfig.java
  17. 152 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java
  18. 20 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java
  19. 27 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java
  20. 3 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java
  21. 103 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java
  22. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java
  23. 82 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java
  24. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java
  25. 80 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java
  26. 28 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java
  27. 148 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java
  28. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java
  29. 80 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java
  30. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java
  31. 27 5
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java
  32. 1 1
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java
  33. 2 1
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
  34. 2 1
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java
  35. 27 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  36. 8 7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  37. 2 28
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  38. 9 8
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  39. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  40. 6 32
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  41. 1 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java

+ 31 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -28,12 +28,12 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResu
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@@ -41,14 +41,19 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInf
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
@@ -192,12 +197,22 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new ParseField(ZeroShotClassificationConfig.NAME),
             ZeroShotClassificationConfig::fromXContentStrict));
 
+        // Inference Configs Update
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
             ClassificationConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, new ParseField(FillMaskConfigUpdate.NAME),
+            FillMaskConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, new ParseField(NerConfigUpdate.NAME),
+            NerConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, new ParseField(PassThroughConfigUpdate.NAME),
+            PassThroughConfigUpdate::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
             RegressionConfigUpdate::fromXContentStrict));
-        namedXContent.add(
-            new NamedXContentRegistry.Entry(
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, new ParseField(TextClassificationConfig.NAME),
+            TextClassificationConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, new ParseField(TextEmbeddingConfigUpdate.NAME),
+            TextEmbeddingConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(
                 InferenceConfigUpdate.class,
                 new ParseField(ZeroShotClassificationConfigUpdate.NAME),
                 ZeroShotClassificationConfigUpdate::fromXContentStrict
@@ -305,14 +320,25 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
             ZeroShotClassificationConfig.NAME, ZeroShotClassificationConfig::new));
 
+        // Inference Configs Updates
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            EmptyConfigUpdate.NAME, EmptyConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            FillMaskConfigUpdate.NAME, FillMaskConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            NerConfigUpdate.NAME, NerConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            PassThroughConfigUpdate.NAME, PassThroughConfigUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             RegressionConfigUpdate.NAME.getPreferredName(), RegressionConfigUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ResultsFieldUpdate.NAME, ResultsFieldUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
-            EmptyConfigUpdate.NAME, EmptyConfigUpdate::new));
+            TextClassificationConfigUpdate.NAME, TextClassificationConfigUpdate::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
+            TextEmbeddingConfigUpdate.NAME, TextClassificationConfigUpdate::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
             ZeroShotClassificationConfigUpdate.NAME, ZeroShotClassificationConfigUpdate::new));
 

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java

@@ -110,6 +110,7 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str
         return topClassesResultsField;
     }
 
+    @Override
     public String getResultsField() {
         return resultsField;
     }

+ 84 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfig.java

@@ -10,7 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
@@ -25,23 +25,23 @@ import java.util.Optional;
 public class FillMaskConfig implements NlpConfig {
 
     public static final String NAME = "fill_mask";
+    public static final int DEFAULT_NUM_RESULTS = 5;
 
     public static FillMaskConfig fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+        return STRICT_PARSER.apply(parser, null).build();
     }
 
     public static FillMaskConfig fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+        return LENIENT_PARSER.apply(parser, null).build();
     }
 
-    private static final ConstructingObjectParser<FillMaskConfig, Void> STRICT_PARSER = createParser(false);
-    private static final ConstructingObjectParser<FillMaskConfig, Void> LENIENT_PARSER = createParser(true);
+    private static final ObjectParser<FillMaskConfig.Builder, Void> STRICT_PARSER = createParser(false);
+    private static final ObjectParser<FillMaskConfig.Builder, Void> LENIENT_PARSER = createParser(true);
 
-    private static ConstructingObjectParser<FillMaskConfig, Void> createParser(boolean ignoreUnknownFields) {
-        ConstructingObjectParser<FillMaskConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new FillMaskConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
+    private static ObjectParser<FillMaskConfig.Builder, Void> createParser(boolean ignoreUnknownFields) {
+        ObjectParser<FillMaskConfig.Builder, Void> parser = new ObjectParser<>(NAME, ignoreUnknownFields, Builder::new);
         parser.declareObject(
-            ConstructingObjectParser.optionalConstructorArg(),
+            Builder::setVocabularyConfig,
             (p, c) -> {
                 if (ignoreUnknownFields == false) {
                     throw ExceptionsHelper.badRequestException(
@@ -54,24 +54,35 @@ public class FillMaskConfig implements NlpConfig {
             VOCABULARY
         );
         parser.declareNamedObject(
-            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+            Builder::setTokenization, (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
                 TOKENIZATION
         );
+        parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        parser.declareString(Builder::setResultsField, RESULTS_FIELD);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
     private final Tokenization tokenization;
+    private final int numTopClasses;
+    private final String resultsField;
 
-    public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
+    public FillMaskConfig(@Nullable VocabularyConfig vocabularyConfig,
+                          @Nullable Tokenization tokenization,
+                          @Nullable Integer numTopClasses,
+                          @Nullable String resultsField) {
         this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_RESULTS : numTopClasses;
+        this.resultsField = resultsField;
     }
 
     public FillMaskConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
         tokenization = in.readNamedWriteable(Tokenization.class);
+        numTopClasses = in.readInt();
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -79,6 +90,10 @@ public class FillMaskConfig implements NlpConfig {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
         NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -92,6 +107,8 @@ public class FillMaskConfig implements NlpConfig {
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
         out.writeNamedWriteable(tokenization);
+        out.writeInt(numTopClasses);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -116,12 +133,14 @@ public class FillMaskConfig implements NlpConfig {
 
         FillMaskConfig that = (FillMaskConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenization, that.tokenization);
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(resultsField, that.resultsField)
+            && numTopClasses == that.numTopClasses;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization);
+        return Objects.hash(vocabularyConfig, tokenization, numTopClasses, resultsField);
     }
 
     @Override
@@ -134,8 +153,60 @@ public class FillMaskConfig implements NlpConfig {
         return tokenization;
     }
 
+    public int getNumTopClasses() {
+        return numTopClasses;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
     @Override
     public boolean isAllocateOnly() {
         return true;
     }
+
+    public static class Builder {
+        private VocabularyConfig vocabularyConfig;
+        private Tokenization tokenization;
+        private int numTopClasses;
+        private String resultsField;
+
+        Builder() {}
+
+        Builder(FillMaskConfig config) {
+            this.vocabularyConfig = config.vocabularyConfig;
+            this.tokenization = config.tokenization;
+            this.numTopClasses = config.numTopClasses;
+            this.resultsField = config.resultsField;
+        }
+
+        public FillMaskConfig.Builder setVocabularyConfig(VocabularyConfig vocabularyConfig) {
+            this.vocabularyConfig = vocabularyConfig;
+            return this;
+        }
+
+        public FillMaskConfig.Builder setTokenization(Tokenization tokenization) {
+            this.tokenization = tokenization;
+            return this;
+        }
+
+        public FillMaskConfig.Builder setNumTopClasses(Integer numTopClasses) {
+            this.numTopClasses = numTopClasses;
+            return this;
+        }
+
+        public FillMaskConfig.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public FillMaskConfig build() {
+            return new FillMaskConfig(vocabularyConfig,
+                tokenization,
+                numTopClasses,
+                resultsField);
+        }
+    }
 }

+ 179 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdate.java

@@ -0,0 +1,179 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.NUM_TOP_CLASSES;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+
+public class FillMaskConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+
+    public static final String NAME = FillMaskConfig.NAME;
+
+    public static FillMaskConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new FillMaskConfigUpdate(numTopClasses, resultsField);
+    }
+
+    private static final ObjectParser<FillMaskConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<FillMaskConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<FillMaskConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+            NAME,
+            lenient,
+            FillMaskConfigUpdate.Builder::new);
+        parser.declareString(FillMaskConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        parser.declareInt(FillMaskConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        return parser;
+    }
+
+    public static FillMaskConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final Integer numTopClasses;
+    private final String resultsField;
+
+    public FillMaskConfigUpdate(Integer numTopClasses, String resultsField) {
+        this.numTopClasses = numTopClasses;
+        this.resultsField = resultsField;
+    }
+
+    public FillMaskConfigUpdate(StreamInput in) throws IOException {
+        this.numTopClasses = in.readOptionalInt();
+        this.resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalInt(numTopClasses);
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (numTopClasses != null) {
+            builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
+        }
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof FillMaskConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        FillMaskConfig fillMaskConfig = (FillMaskConfig)originalConfig;
+        if (isNoop(fillMaskConfig)) {
+            return originalConfig;
+        }
+
+        FillMaskConfig.Builder builder = new FillMaskConfig.Builder(fillMaskConfig);
+        if (numTopClasses != null) {
+            builder.setNumTopClasses(numTopClasses);
+        }
+        if (resultsField != null) {
+            builder.setResultsField(resultsField);
+        }
+        return builder.build();
+    }
+
+    boolean isNoop(FillMaskConfig originalConfig) {
+        return (this.numTopClasses == null || this.numTopClasses == originalConfig.getNumTopClasses()) &&
+            (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()));
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof FillMaskConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder()
+            .setNumTopClasses(numTopClasses)
+            .setResultsField(resultsField);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        FillMaskConfigUpdate that = (FillMaskConfigUpdate) o;
+        return Objects.equals(numTopClasses, that.numTopClasses) &&
+            Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(numTopClasses, resultsField);
+    }
+
+    public static class Builder
+        implements InferenceConfigUpdate.Builder<FillMaskConfigUpdate.Builder, FillMaskConfigUpdate> {
+        private Integer numTopClasses;
+        private String resultsField;
+
+        public FillMaskConfigUpdate.Builder setNumTopClasses(Integer numTopClasses) {
+            this.numTopClasses = numTopClasses;
+            return this;
+        }
+
+        @Override
+        public FillMaskConfigUpdate.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public FillMaskConfigUpdate build() {
+            return new FillMaskConfigUpdate(this.numTopClasses, this.resultsField);
+        }
+    }
+}

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

@@ -24,5 +24,7 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
         return false;
     }
 
+    String getResultsField();
+
     boolean isAllocateOnly();
 }

+ 20 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfig.java

@@ -42,7 +42,7 @@ public class NerConfig implements NlpConfig {
     @SuppressWarnings({ "unchecked"})
     private static ConstructingObjectParser<NerConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<NerConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2]));
+            a -> new NerConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (String) a[3]));
         parser.declareObject(
             ConstructingObjectParser.optionalConstructorArg(),
             (p, c) -> {
@@ -61,26 +61,32 @@ public class NerConfig implements NlpConfig {
                 TOKENIZATION
         );
         parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
+
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
     private final Tokenization tokenization;
     private final List<String> classificationLabels;
+    private final String resultsField;
 
     public NerConfig(@Nullable VocabularyConfig vocabularyConfig,
                      @Nullable Tokenization tokenization,
-                     @Nullable List<String> classificationLabels) {
+                     @Nullable List<String> classificationLabels,
+                     @Nullable String resultsField) {
         this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
         this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
+        this.resultsField = resultsField;
     }
 
     public NerConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
         tokenization = in.readNamedWriteable(Tokenization.class);
         classificationLabels = in.readStringList();
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -88,6 +94,7 @@ public class NerConfig implements NlpConfig {
         vocabularyConfig.writeTo(out);
         out.writeNamedWriteable(tokenization);
         out.writeStringCollection(classificationLabels);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -98,6 +105,9 @@ public class NerConfig implements NlpConfig {
         if (classificationLabels.isEmpty() == false) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -130,12 +140,13 @@ public class NerConfig implements NlpConfig {
         NerConfig that = (NerConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
             && Objects.equals(tokenization, that.tokenization)
-            && Objects.equals(classificationLabels, that.classificationLabels);
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization, classificationLabels);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, resultsField);
     }
 
     @Override
@@ -152,6 +163,11 @@ public class NerConfig implements NlpConfig {
         return classificationLabels;
     }
 
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
     @Override
     public boolean isAllocateOnly() {
         return true;

+ 153 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdate.java

@@ -0,0 +1,153 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+
+public class NerConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+    public static final String NAME = NerConfig.NAME;
+
+    public static NerConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new NerConfigUpdate(resultsField);
+    }
+
+    private static final ObjectParser<NerConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<NerConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<NerConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+            NAME,
+            lenient,
+            NerConfigUpdate.Builder::new);
+        parser.declareString(NerConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        return parser;
+    }
+
+    public static NerConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final String resultsField;
+
+    public NerConfigUpdate(String resultsField) {
+        this.resultsField = resultsField;
+    }
+
+    public NerConfigUpdate(StreamInput in) throws IOException {
+        this.resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (resultsField == null || resultsField.equals(originalConfig.getResultsField())) {
+            return originalConfig;
+        }
+
+        if (originalConfig instanceof NerConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        NerConfig nerConfig = (NerConfig)originalConfig;
+        return new NerConfig(
+            nerConfig.getVocabularyConfig(),
+            nerConfig.getTokenization(),
+            nerConfig.getClassificationLabels(),
+            resultsField);
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof NerConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new NerConfigUpdate.Builder()
+            .setResultsField(resultsField);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        NerConfigUpdate that = (NerConfigUpdate) o;
+        return Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resultsField);
+    }
+
+    public static class Builder
+        implements InferenceConfigUpdate.Builder<NerConfigUpdate.Builder, NerConfigUpdate> {
+        private String resultsField;
+
+        @Override
+        public NerConfigUpdate.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public NerConfigUpdate build() {
+            return new NerConfigUpdate(this.resultsField);
+        }
+    }
+}
+

+ 3 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java

@@ -14,6 +14,9 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse
     ParseField VOCABULARY = new ParseField("vocabulary");
     ParseField TOKENIZATION = new ParseField("tokenization");
     ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+    ParseField RESULTS_FIELD = new ParseField("results_field");
+    ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
+
 
     /**
      * @return the vocabulary configuration that allows retrieving it

+ 0 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfigUpdate.java

@@ -7,12 +7,8 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
-import org.elasticsearch.common.xcontent.ParseField;
-
 public abstract class NlpConfigUpdate implements InferenceConfigUpdate {
 
-    static ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
-
     @Override
     public InferenceConfig toConfig() {
         throw new UnsupportedOperationException("cannot serialize to nodes before 7.8");

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java

@@ -63,4 +63,9 @@ public class NullInferenceConfig implements InferenceConfig {
     public boolean isAllocateOnly() {
         return false;
     }
+
+    @Override
+    public String getResultsField() {
+        return null;
+    }
 }

+ 21 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfig.java

@@ -39,7 +39,7 @@ public class PassThroughConfig implements NlpConfig {
 
     private static ConstructingObjectParser<PassThroughConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<PassThroughConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
+            a -> new PassThroughConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2]));
         parser.declareObject(
             ConstructingObjectParser.optionalConstructorArg(),
             (p, c) -> {
@@ -57,21 +57,28 @@ public class PassThroughConfig implements NlpConfig {
             ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
             TOKENIZATION
         );
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
     private final Tokenization tokenization;
+    private final String resultsField;
 
-    public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
+    public PassThroughConfig(@Nullable VocabularyConfig vocabularyConfig,
+                             @Nullable Tokenization tokenization,
+                             @Nullable String resultsField
+    ) {
         this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        this.resultsField = resultsField;
     }
 
     public PassThroughConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
         tokenization = in.readNamedWriteable(Tokenization.class);
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -79,6 +86,9 @@ public class PassThroughConfig implements NlpConfig {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
         NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -92,6 +102,7 @@ public class PassThroughConfig implements NlpConfig {
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
         out.writeNamedWriteable(tokenization);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -121,12 +132,13 @@ public class PassThroughConfig implements NlpConfig {
 
         PassThroughConfig that = (PassThroughConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenization, that.tokenization);
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization);
+        return Objects.hash(vocabularyConfig, tokenization, resultsField);
     }
 
     @Override
@@ -138,4 +150,9 @@ public class PassThroughConfig implements NlpConfig {
     public Tokenization getTokenization() {
         return tokenization;
     }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
 }

+ 151 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdate.java

@@ -0,0 +1,151 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+
+public class PassThroughConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+    public static final String NAME = PassThroughConfig.NAME;
+
+    public static PassThroughConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new PassThroughConfigUpdate(resultsField);
+    }
+
+    private static final ObjectParser<PassThroughConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<PassThroughConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<PassThroughConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+            NAME,
+            lenient,
+            PassThroughConfigUpdate.Builder::new);
+        parser.declareString(PassThroughConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        return parser;
+    }
+
+    public static PassThroughConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final String resultsField;
+
+    public PassThroughConfigUpdate(String resultsField) {
+        this.resultsField = resultsField;
+    }
+
+    public PassThroughConfigUpdate(StreamInput in) throws IOException {
+        this.resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (resultsField == null || resultsField.equals(originalConfig.getResultsField())) {
+            return originalConfig;
+        }
+
+        if (originalConfig instanceof PassThroughConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        PassThroughConfig passThroughConfig = (PassThroughConfig)originalConfig;
+        return new PassThroughConfig(
+            passThroughConfig.getVocabularyConfig(),
+            passThroughConfig.getTokenization(),
+            resultsField);
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof PassThroughConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new PassThroughConfigUpdate.Builder()
+            .setResultsField(resultsField);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        PassThroughConfigUpdate that = (PassThroughConfigUpdate) o;
+        return Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resultsField);
+    }
+
+    public static class Builder
+        implements InferenceConfigUpdate.Builder<PassThroughConfigUpdate.Builder, PassThroughConfigUpdate> {
+        private String resultsField;
+
+        @Override
+        public PassThroughConfigUpdate.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public PassThroughConfigUpdate build() {
+            return new PassThroughConfigUpdate(this.resultsField);
+        }
+    }
+}

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java

@@ -73,6 +73,7 @@ public class RegressionConfig implements LenientlyParsedInferenceConfig, Strictl
         return numTopFeatureImportanceValues;
     }
 
+    @Override
     public String getResultsField() {
         return resultsField;
     }

+ 97 - 24
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfig.java

@@ -10,8 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ConstructingObjectParser;
-import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.Nullable;
@@ -20,7 +19,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
-import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -28,25 +26,24 @@ import java.util.Optional;
 public class TextClassificationConfig implements NlpConfig {
 
     public static final String NAME = "text_classification";
-    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
 
     public static TextClassificationConfig fromXContentStrict(XContentParser parser) {
-        return STRICT_PARSER.apply(parser, null);
+        return STRICT_PARSER.apply(parser, null).build();
     }
 
     public static TextClassificationConfig fromXContentLenient(XContentParser parser) {
-        return LENIENT_PARSER.apply(parser, null);
+        return LENIENT_PARSER.apply(parser, null).build();
     }
 
-    private static final ConstructingObjectParser<TextClassificationConfig, Void> STRICT_PARSER = createParser(false);
-    private static final ConstructingObjectParser<TextClassificationConfig, Void> LENIENT_PARSER = createParser(true);
+    private static final ObjectParser<TextClassificationConfig.Builder, Void> STRICT_PARSER = createParser(false);
+    private static final ObjectParser<TextClassificationConfig.Builder, Void> LENIENT_PARSER = createParser(true);
+
+    private static ObjectParser<TextClassificationConfig.Builder, Void> createParser(boolean ignoreUnknownFields) {
+        ObjectParser<TextClassificationConfig.Builder, Void> parser =
+            new ObjectParser<>(NAME, ignoreUnknownFields, Builder::new);
 
-    @SuppressWarnings({ "unchecked"})
-    private static ConstructingObjectParser<TextClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
-        ConstructingObjectParser<TextClassificationConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new TextClassificationConfig((VocabularyConfig) a[0], (Tokenization) a[1], (List<String>) a[2], (Integer) a[3]));
         parser.declareObject(
-            ConstructingObjectParser.optionalConstructorArg(),
+            Builder::setVocabularyConfig,
             (p, c) -> {
                 if (ignoreUnknownFields == false) {
                     throw ExceptionsHelper.badRequestException(
@@ -59,11 +56,12 @@ public class TextClassificationConfig implements NlpConfig {
             VOCABULARY
         );
         parser.declareNamedObject(
-            ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
+            Builder::setTokenization, (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
                 TOKENIZATION
         );
-        parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), CLASSIFICATION_LABELS);
-        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
+        parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+        parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        parser.declareString(Builder::setResultsField, RESULTS_FIELD);
         return parser;
     }
 
@@ -71,16 +69,32 @@ public class TextClassificationConfig implements NlpConfig {
     private final Tokenization tokenization;
     private final List<String> classificationLabels;
     private final int numTopClasses;
+    private final String resultsField;
 
     public TextClassificationConfig(@Nullable VocabularyConfig vocabularyConfig,
                                     @Nullable Tokenization tokenization,
-                                    @Nullable List<String> classificationLabels,
-                                    @Nullable Integer numTopClasses) {
+                                    List<String> classificationLabels,
+                                    @Nullable Integer numTopClasses,
+                                    @Nullable String resultsField) {
         this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
-        this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
+        if (classificationLabels == null || classificationLabels.size() < 2) {
+            throw ExceptionsHelper.badRequestException("[{}] requires at least 2 [{}]; provided {}",
+                NAME, CLASSIFICATION_LABELS, classificationLabels);
+        }
+        this.classificationLabels = classificationLabels;
         this.numTopClasses = Optional.ofNullable(numTopClasses).orElse(-1);
+        if (this.numTopClasses == 0) {
+            throw ExceptionsHelper.badRequestException(
+                    "[{}] requires at least 1 [{}]; provided [{}]",
+                    NAME,
+                    TextClassificationConfig.NUM_TOP_CLASSES,
+                    numTopClasses
+            );
+        }
+        this.resultsField = resultsField;
+
     }
 
     public TextClassificationConfig(StreamInput in) throws IOException {
@@ -88,6 +102,7 @@ public class TextClassificationConfig implements NlpConfig {
         tokenization = in.readNamedWriteable(Tokenization.class);
         classificationLabels = in.readStringList();
         numTopClasses = in.readInt();
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -96,6 +111,7 @@ public class TextClassificationConfig implements NlpConfig {
         out.writeNamedWriteable(tokenization);
         out.writeStringCollection(classificationLabels);
         out.writeInt(numTopClasses);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -103,10 +119,11 @@ public class TextClassificationConfig implements NlpConfig {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
         NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
-        if (classificationLabels.isEmpty() == false) {
-            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
-        }
+        builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -140,12 +157,13 @@ public class TextClassificationConfig implements NlpConfig {
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
             && Objects.equals(tokenization, that.tokenization)
             && Objects.equals(numTopClasses, that.numTopClasses)
-            && Objects.equals(classificationLabels, that.classificationLabels);
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, numTopClasses, resultsField);
     }
 
     @Override
@@ -166,8 +184,63 @@ public class TextClassificationConfig implements NlpConfig {
         return numTopClasses;
     }
 
+    public String getResultsField() {
+        return resultsField;
+    }
+
     @Override
     public boolean isAllocateOnly() {
         return true;
     }
+
+    public static class Builder {
+        private VocabularyConfig vocabularyConfig;
+        private Tokenization tokenization;
+        private List<String> classificationLabels;
+        private int numTopClasses;
+        private String resultsField;
+
+        Builder() {}
+
+        Builder(TextClassificationConfig config) {
+            this.vocabularyConfig = config.vocabularyConfig;
+            this.tokenization = config.tokenization;
+            this.classificationLabels = config.classificationLabels;
+            this.numTopClasses = config.numTopClasses;
+            this.resultsField = config.resultsField;
+        }
+
+        public Builder setVocabularyConfig(VocabularyConfig vocabularyConfig) {
+            this.vocabularyConfig = vocabularyConfig;
+            return this;
+        }
+
+        public Builder setTokenization(Tokenization tokenization) {
+            this.tokenization = tokenization;
+            return this;
+        }
+
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        public Builder setNumTopClasses(Integer numTopClasses) {
+            this.numTopClasses = numTopClasses;
+            return this;
+        }
+
+        public Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public TextClassificationConfig build() {
+            return new TextClassificationConfig(vocabularyConfig,
+                tokenization,
+                classificationLabels,
+                numTopClasses,
+                resultsField);
+        }
+    }
 }

+ 211 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdate.java

@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.CLASSIFICATION_LABELS;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.NUM_TOP_CLASSES;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig.RESULTS_FIELD;
+
+public class TextClassificationConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+
+    public static final String NAME = TextClassificationConfig.NAME;
+
+    @SuppressWarnings("unchecked")
+    public static TextClassificationConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+        List<String> classificationLabels = (List<String>)options.remove(CLASSIFICATION_LABELS.getPreferredName());
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new TextClassificationConfigUpdate(classificationLabels, numTopClasses, resultsField);
+    }
+
+    private static final ObjectParser<TextClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<TextClassificationConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<Builder, Void> parser = new ObjectParser<>(
+            NAME,
+            lenient,
+            TextClassificationConfigUpdate.Builder::new);
+        parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+        parser.declareString(Builder::setResultsField, RESULTS_FIELD);
+        parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        return parser;
+    }
+
+    public static TextClassificationConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final List<String> classificationLabels;
+    private final Integer numTopClasses;
+    private final String resultsField;
+
+    public TextClassificationConfigUpdate(List<String> classificationLabels, Integer numTopClasses, String resultsField) {
+        this.classificationLabels = classificationLabels;
+        this.numTopClasses = numTopClasses;
+        this.resultsField = resultsField;
+    }
+
+    public TextClassificationConfigUpdate(StreamInput in) throws IOException {
+        classificationLabels = in.readOptionalStringList();
+        numTopClasses = in.readOptionalVInt();
+        resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalStringCollection(classificationLabels);
+        out.writeOptionalVInt(numTopClasses);
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof TextClassificationConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        TextClassificationConfig classificationConfig = (TextClassificationConfig)originalConfig;
+        if (isNoop(classificationConfig)) {
+            return originalConfig;
+        }
+
+        TextClassificationConfig.Builder builder = new TextClassificationConfig.Builder(classificationConfig);
+        if (numTopClasses != null) {
+            builder.setNumTopClasses(numTopClasses);
+        }
+        if (classificationLabels != null) {
+            if (classificationLabels.size() != classificationConfig.getClassificationLabels().size()) {
+                throw ExceptionsHelper.badRequestException(
+                    "The number of [{}] the model is defined with [{}] does not match the number in the update [{}]",
+                    CLASSIFICATION_LABELS,
+                    classificationConfig.getClassificationLabels().size(),
+                    classificationLabels.size()
+                );
+            }
+            builder.setClassificationLabels(classificationLabels);
+        }
+        if (resultsField != null) {
+            builder.setResultsField(resultsField);
+        }
+        return builder.build();
+    }
+
+    boolean isNoop(TextClassificationConfig originalConfig) {
+        return (this.numTopClasses == null || this.numTopClasses == originalConfig.getNumTopClasses()) &&
+            (this.classificationLabels == null) &&
+            (this.resultsField == null || this.resultsField.equals(originalConfig.getResultsField()));
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof TextClassificationConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder()
+            .setClassificationLabels(classificationLabels)
+            .setNumTopClasses(numTopClasses)
+            .setResultsField(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (numTopClasses != null) {
+            builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
+        }
+        if (classificationLabels != null) {
+            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        }
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TextClassificationConfigUpdate that = (TextClassificationConfigUpdate) o;
+        return Objects.equals(classificationLabels, that.classificationLabels) &&
+            Objects.equals(numTopClasses, that.numTopClasses) &&
+            Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classificationLabels, numTopClasses, resultsField);
+    }
+
+    public static class Builder
+        implements InferenceConfigUpdate.Builder<TextClassificationConfigUpdate.Builder, TextClassificationConfigUpdate> {
+        private List<String> classificationLabels;
+        private Integer numTopClasses;
+        private String resultsField;
+
+        public TextClassificationConfigUpdate.Builder setNumTopClasses(Integer numTopClasses) {
+            this.numTopClasses = numTopClasses;
+            return this;
+        }
+
+        public TextClassificationConfigUpdate.Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        @Override
+        public TextClassificationConfigUpdate.Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public TextClassificationConfigUpdate build() {
+            return new TextClassificationConfigUpdate(this.classificationLabels, this.numTopClasses, this.resultsField);
+        }
+    }
+}

+ 20 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfig.java

@@ -39,7 +39,7 @@ public class TextEmbeddingConfig implements NlpConfig {
 
     private static ConstructingObjectParser<TextEmbeddingConfig, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<TextEmbeddingConfig, Void> parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields,
-            a -> new TextEmbeddingConfig((VocabularyConfig) a[0], (Tokenization) a[1]));
+            a -> new TextEmbeddingConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2]));
         parser.declareObject(
             ConstructingObjectParser.optionalConstructorArg(),
             (p, c) -> {
@@ -57,21 +57,27 @@ public class TextEmbeddingConfig implements NlpConfig {
             ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
             TOKENIZATION
         );
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
         return parser;
     }
 
     private final VocabularyConfig vocabularyConfig;
     private final Tokenization tokenization;
+    private final String resultsField;
 
-    public TextEmbeddingConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization) {
+    public TextEmbeddingConfig(@Nullable VocabularyConfig vocabularyConfig,
+                               @Nullable Tokenization tokenization,
+                               @Nullable String resultsField) {
         this.vocabularyConfig = Optional.ofNullable(vocabularyConfig)
             .orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
         this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
+        this.resultsField = resultsField;
     }
 
     public TextEmbeddingConfig(StreamInput in) throws IOException {
         vocabularyConfig = new VocabularyConfig(in);
         tokenization = in.readNamedWriteable(Tokenization.class);
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -79,6 +85,9 @@ public class TextEmbeddingConfig implements NlpConfig {
         builder.startObject();
         builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
         NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -92,6 +101,7 @@ public class TextEmbeddingConfig implements NlpConfig {
     public void writeTo(StreamOutput out) throws IOException {
         vocabularyConfig.writeTo(out);
         out.writeNamedWriteable(tokenization);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -121,12 +131,13 @@ public class TextEmbeddingConfig implements NlpConfig {
 
         TextEmbeddingConfig that = (TextEmbeddingConfig) o;
         return Objects.equals(vocabularyConfig, that.vocabularyConfig)
-            && Objects.equals(tokenization, that.tokenization);
+            && Objects.equals(tokenization, that.tokenization)
+            && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization);
+        return Objects.hash(vocabularyConfig, tokenization, resultsField);
     }
 
     @Override
@@ -138,4 +149,9 @@ public class TextEmbeddingConfig implements NlpConfig {
     public Tokenization getTokenization() {
         return tokenization;
     }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
 }

+ 152 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java

@@ -0,0 +1,152 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
+
+public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedXContentObject {
+
+    public static final String NAME = TextEmbeddingConfig.NAME;
+
+    public static TextEmbeddingConfigUpdate fromMap(Map<String, Object> map) {
+        Map<String, Object> options = new HashMap<>(map);
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new TextEmbeddingConfigUpdate(resultsField);
+    }
+
+    private static final ObjectParser<TextEmbeddingConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<TextEmbeddingConfigUpdate.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<TextEmbeddingConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+            NAME,
+            lenient,
+            TextEmbeddingConfigUpdate.Builder::new);
+        parser.declareString(TextEmbeddingConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
+        return parser;
+    }
+
+    public static TextEmbeddingConfigUpdate fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    private final String resultsField;
+
+    public TextEmbeddingConfigUpdate(String resultsField) {
+        this.resultsField = resultsField;
+    }
+
+    public TextEmbeddingConfigUpdate(StreamInput in) throws IOException {
+        this.resultsField = in.readOptionalString();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(resultsField);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (resultsField == null || resultsField.equals(originalConfig.getResultsField())) {
+            return originalConfig;
+        }
+
+        if (originalConfig instanceof TextEmbeddingConfig == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of type [{}] can not be updated with a inference request of type [{}]",
+                originalConfig.getName(),
+                getName());
+        }
+
+        TextEmbeddingConfig embeddingConfig = (TextEmbeddingConfig)originalConfig;
+        return new TextEmbeddingConfig(
+            embeddingConfig.getVocabularyConfig(),
+            embeddingConfig.getTokenization(),
+            resultsField);
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return config instanceof TextEmbeddingConfig;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder()
+            .setResultsField(resultsField);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TextEmbeddingConfigUpdate that = (TextEmbeddingConfigUpdate) o;
+        return Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resultsField);
+    }
+
+    public static class Builder
+        implements InferenceConfigUpdate.Builder<TextEmbeddingConfigUpdate.Builder, TextEmbeddingConfigUpdate> {
+        private String resultsField;
+
+        @Override
+        public Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        public TextEmbeddingConfigUpdate build() {
+            return new TextEmbeddingConfigUpdate(this.resultsField);
+        }
+    }
+}

+ 20 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfig.java

@@ -65,7 +65,8 @@ public class ZeroShotClassificationConfig implements NlpConfig {
                 (Tokenization) a[2],
                 (String) a[3],
                 (Boolean) a[4],
-                (List<String>) a[5]
+                (List<String>) a[5],
+                (String) a[6]
             )
         );
         parser.declareStringArray(ConstructingObjectParser.constructorArg(), CLASSIFICATION_LABELS);
@@ -89,6 +90,7 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         parser.declareString(ConstructingObjectParser.optionalConstructorArg(), HYPOTHESIS_TEMPLATE);
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
         parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
+        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
         return parser;
     }
 
@@ -98,6 +100,7 @@ public class ZeroShotClassificationConfig implements NlpConfig {
     private final List<String> labels;
     private final boolean isMultiLabel;
     private final String hypothesisTemplate;
+    private final String resultsField;
 
     public ZeroShotClassificationConfig(
         List<String> classificationLabels,
@@ -105,7 +108,8 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         @Nullable Tokenization tokenization,
         @Nullable String hypothesisTemplate,
         @Nullable Boolean isMultiLabel,
-        @Nullable List<String> labels
+        @Nullable List<String> labels,
+        @Nullable String resultsField
     ) {
         this.classificationLabels = ExceptionsHelper.requireNonNull(classificationLabels, CLASSIFICATION_LABELS);
         if (this.classificationLabels.size() != 3) {
@@ -136,6 +140,7 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         if (labels != null && labels.isEmpty()) {
             throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
         }
+        this.resultsField = resultsField;
     }
 
     public ZeroShotClassificationConfig(StreamInput in) throws IOException {
@@ -145,6 +150,7 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         isMultiLabel = in.readBoolean();
         hypothesisTemplate = in.readString();
         labels = in.readOptionalStringList();
+        resultsField = in.readOptionalString();
     }
 
     @Override
@@ -155,6 +161,7 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         out.writeBoolean(isMultiLabel);
         out.writeString(hypothesisTemplate);
         out.writeOptionalStringCollection(labels);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -168,6 +175,9 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         if (labels != null) {
             builder.field(LABELS.getPreferredName(), labels);
         }
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -203,12 +213,13 @@ public class ZeroShotClassificationConfig implements NlpConfig {
             && Objects.equals(isMultiLabel, that.isMultiLabel)
             && Objects.equals(hypothesisTemplate, that.hypothesisTemplate)
             && Objects.equals(labels, that.labels)
-            && Objects.equals(classificationLabels, that.classificationLabels);
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, hypothesisTemplate, isMultiLabel, labels);
+        return Objects.hash(vocabularyConfig, tokenization, classificationLabels, hypothesisTemplate, isMultiLabel, labels, resultsField);
     }
 
     @Override
@@ -237,6 +248,11 @@ public class ZeroShotClassificationConfig implements NlpConfig {
         return Optional.ofNullable(labels).orElse(List.of());
     }
 
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
     @Override
     public boolean isAllocateOnly() {
         return true;

+ 27 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdate.java

@@ -23,6 +23,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig.RESULTS_FIELD;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.LABELS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig.MULTI_LABEL;
 
@@ -39,46 +40,53 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         Map<String, Object> options = new HashMap<>(map);
         Boolean isMultiLabel = (Boolean)options.remove(MULTI_LABEL.getPreferredName());
         List<String> labels = (List<String>)options.remove(LABELS.getPreferredName());
+        String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
         }
-        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+        return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField);
     }
 
     @SuppressWarnings({ "unchecked"})
     private static final ConstructingObjectParser<ZeroShotClassificationConfigUpdate, Void> STRICT_PARSER = new ConstructingObjectParser<>(
         NAME,
-        a -> new ZeroShotClassificationConfigUpdate((List<String>)a[0], (Boolean) a[1])
+        a -> new ZeroShotClassificationConfigUpdate((List<String>)a[0], (Boolean) a[1], (String) a[2])
     );
 
     static {
         STRICT_PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
         STRICT_PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
+        STRICT_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
     }
 
     private final List<String> labels;
     private final Boolean isMultiLabel;
+    private final String resultsField;
 
     public ZeroShotClassificationConfigUpdate(
         @Nullable List<String> labels,
-        @Nullable Boolean isMultiLabel
+        @Nullable Boolean isMultiLabel,
+        @Nullable String resultsField
     ) {
         this.labels = labels;
         if (labels != null && labels.isEmpty()) {
             throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
         }
         this.isMultiLabel = isMultiLabel;
+        this.resultsField = resultsField;
     }
 
     public ZeroShotClassificationConfigUpdate(StreamInput in) throws IOException {
         labels = in.readOptionalStringList();
         isMultiLabel = in.readOptionalBoolean();
+        resultsField = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalStringCollection(labels);
         out.writeOptionalBoolean(isMultiLabel);
+        out.writeOptionalString(resultsField);
     }
 
     @Override
@@ -90,6 +98,9 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         if (isMultiLabel != null) {
             builder.field(MULTI_LABEL.getPreferredName(), isMultiLabel);
         }
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
         builder.endObject();
         return builder;
     }
@@ -125,13 +136,15 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
             zeroShotConfig.getTokenization(),
             zeroShotConfig.getHypothesisTemplate(),
             Optional.ofNullable(isMultiLabel).orElse(zeroShotConfig.isMultiLabel()),
-            Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels())
+            Optional.ofNullable(labels).orElse(zeroShotConfig.getLabels()),
+            Optional.ofNullable(resultsField).orElse(zeroShotConfig.getResultsField())
         );
     }
 
     boolean isNoop(ZeroShotClassificationConfig originalConfig) {
         return (labels == null || labels.equals(originalConfig.getClassificationLabels()))
-            && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()));
+            && (isMultiLabel == null || isMultiLabel.equals(originalConfig.isMultiLabel()))
+            && (resultsField == null || resultsField.equals(originalConfig.getResultsField()));
     }
 
     @Override
@@ -141,7 +154,7 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
 
     @Override
     public String getResultsField() {
-        return null;
+        return resultsField;
     }
 
     @Override
@@ -160,12 +173,14 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         if (o == null || getClass() != o.getClass()) return false;
 
         ZeroShotClassificationConfigUpdate that = (ZeroShotClassificationConfigUpdate) o;
-        return Objects.equals(isMultiLabel, that.isMultiLabel) && Objects.equals(labels, that.labels);
+        return Objects.equals(isMultiLabel, that.isMultiLabel) &&
+            Objects.equals(labels, that.labels) &&
+            Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(labels, isMultiLabel);
+        return Objects.hash(labels, isMultiLabel, resultsField);
     }
 
     public List<String> getLabels() {
@@ -178,10 +193,12 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         > {
         private List<String> labels;
         private Boolean isMultiLabel;
+        private String resultsField;
 
         @Override
         public ZeroShotClassificationConfigUpdate.Builder setResultsField(String resultsField) {
-            throw new IllegalArgumentException();
+            this.resultsField = resultsField;
+            return this;
         }
 
         public Builder setLabels(List<String> labels) {
@@ -195,7 +212,7 @@ public class ZeroShotClassificationConfigUpdate extends NlpConfigUpdate implemen
         }
 
         public ZeroShotClassificationConfigUpdate build() {
-            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel);
+            return new ZeroShotClassificationConfigUpdate(labels, isMultiLabel, resultsField);
         }
     }
 }

+ 3 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigTests.java

@@ -50,7 +50,9 @@ public class FillMaskConfigTests extends InferenceConfigItemTestCase<FillMaskCon
     public static FillMaskConfig createRandom() {
         return new FillMaskConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom()
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomInt(),
+            randomBoolean() ? null : randomAlphaOfLength(5)
         );
     }
 }

+ 103 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java

@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class FillMaskConfigUpdateTests extends AbstractBWCSerializationTestCase<FillMaskConfigUpdate> {
+
+    public void testFromMap() {
+        FillMaskConfigUpdate expected = new FillMaskConfigUpdate(3, "ml-results");
+        Map<String, Object> config = new HashMap<>(){{
+            put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+            put(NlpConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
+        }};
+        assertThat(FillMaskConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> FillMaskConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+    public void testIsNoop() {
+        assertTrue(new FillMaskConfigUpdate.Builder().build().isNoop(FillMaskConfigTests.createRandom()));
+
+        assertFalse(new FillMaskConfigUpdate.Builder()
+            .setResultsField("foo")
+            .build()
+            .isNoop(new FillMaskConfig.Builder().setResultsField("bar").build()));
+
+        assertTrue(new FillMaskConfigUpdate.Builder()
+            .setNumTopClasses(3)
+            .build()
+            .isNoop(new FillMaskConfig.Builder().setNumTopClasses(3).build()));
+    }
+
+    public void testApply() {
+        FillMaskConfig originalConfig = FillMaskConfigTests.createRandom();
+
+        assertThat(originalConfig, equalTo(new FillMaskConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(new FillMaskConfig.Builder(originalConfig)
+                .setResultsField("ml-results")
+                .build(),
+            equalTo(new FillMaskConfigUpdate.Builder()
+                .setResultsField("ml-results")
+                .build()
+                .apply(originalConfig)
+            ));
+        assertThat(new FillMaskConfig.Builder(originalConfig)
+                .setNumTopClasses(originalConfig.getNumTopClasses() +1)
+                .build(),
+            equalTo(new FillMaskConfigUpdate.Builder()
+                .setNumTopClasses(originalConfig.getNumTopClasses() +1)
+                .build()
+                .apply(originalConfig)
+            ));
+    }
+
+    @Override
+    protected FillMaskConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return FillMaskConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<FillMaskConfigUpdate> instanceReader() {
+        return FillMaskConfigUpdate::new;
+    }
+
+    @Override
+    protected FillMaskConfigUpdate createTestInstance() {
+        FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder();
+        if (randomBoolean()) {
+            builder.setNumTopClasses(randomIntBetween(1, 4));
+        }
+        if (randomBoolean()) {
+            builder.setResultsField(randomAlphaOfLength(8));
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) {
+        return instance;
+    }
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigTests.java

@@ -51,7 +51,8 @@ public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {
         return new NerConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
             randomBoolean() ? null : BertTokenizationTests.createRandom(),
-            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10))
+            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomAlphaOfLength(5)
         );
     }
 }

+ 82 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class NerConfigUpdateTests extends AbstractBWCSerializationTestCase<NerConfigUpdate> {
+
+    public void testFromMap() {
+        NerConfigUpdate expected = new NerConfigUpdate("ml-results");
+        Map<String, Object> config = new HashMap<>(){{
+            put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+        }};
+        assertThat(NerConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> NerConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+
+    public void testApply() {
+        NerConfig originalConfig = NerConfigTests.createRandom();
+
+        assertThat(originalConfig, sameInstance(new NerConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(new NerConfig(
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getClassificationLabels(),
+                "ml-results"),
+            equalTo(new NerConfigUpdate.Builder()
+                .setResultsField("ml-results")
+                .build()
+                .apply(originalConfig)
+            ));
+    }
+
+    @Override
+    protected NerConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return NerConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<NerConfigUpdate> instanceReader() {
+        return NerConfigUpdate::new;
+    }
+
+    @Override
+    protected NerConfigUpdate createTestInstance() {
+        NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder();
+        if (randomBoolean()) {
+            builder.setResultsField(randomAlphaOfLength(8));
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) {
+        return instance;
+    }
+}
+

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigTests.java

@@ -50,7 +50,8 @@ public class PassThroughConfigTests extends InferenceConfigItemTestCase<PassThro
     public static PassThroughConfig createRandom() {
         return new PassThroughConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom()
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }
 }

+ 80 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class PassThroughConfigUpdateTests extends AbstractBWCSerializationTestCase<PassThroughConfigUpdate> {
+
+    public void testFromMap() {
+        PassThroughConfigUpdate expected = new PassThroughConfigUpdate("ml-results");
+        Map<String, Object> config = new HashMap<>(){{
+            put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+        }};
+        assertThat(PassThroughConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> PassThroughConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+
+    public void testApply() {
+        PassThroughConfig originalConfig = PassThroughConfigTests.createRandom();
+
+        assertThat(originalConfig, sameInstance(new PassThroughConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(new PassThroughConfig(
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                "ml-results"),
+            equalTo(new PassThroughConfigUpdate.Builder()
+                .setResultsField("ml-results")
+                .build()
+                .apply(originalConfig)
+            ));
+    }
+
+    @Override
+    protected PassThroughConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return PassThroughConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<PassThroughConfigUpdate> instanceReader() {
+        return PassThroughConfigUpdate::new;
+    }
+
+    @Override
+    protected PassThroughConfigUpdate createTestInstance() {
+        PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder();
+        if (randomBoolean()) {
+            builder.setResultsField(randomAlphaOfLength(8));
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) {
+        return instance;
+    }
+}

+ 28 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigTests.java

@@ -7,14 +7,18 @@
 
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.function.Predicate;
 
+import static org.hamcrest.Matchers.containsString;
+
 public class TextClassificationConfigTests extends InferenceConfigItemTestCase<TextClassificationConfig> {
 
     @Override
@@ -47,12 +51,34 @@ public class TextClassificationConfigTests extends InferenceConfigItemTestCase<T
         return instance;
     }
 
+    public void testInvalidClassificationLabels() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new TextClassificationConfig(null, null, null, null, null));
+
+        assertThat(e.getMessage(),
+            containsString("[text_classification] requires at least 2 [classification_labels]; provided null"));
+
+        e = expectThrows(ElasticsearchStatusException.class,
+            () -> new TextClassificationConfig(null, null, List.of("too-few"), null, null));
+        assertThat(e.getMessage(),
+            containsString("[text_classification] requires at least 2 [classification_labels]; provided [too-few]"));
+    }
+
+    public void testInvalidNumClasses() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new TextClassificationConfig(null, null, List.of("one", "two"), 0, null));
+        assertThat(e.getMessage(),
+            containsString("[text_classification] requires at least 1 [num_top_classes]; provided [0]"));
+    }
+
+
     public static TextClassificationConfig createRandom() {
         return new TextClassificationConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
             randomBoolean() ? null : BertTokenizationTests.createRandom(),
-            randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
-            randomBoolean() ? null : randomIntBetween(-1, 10)
+            randomList(2, 5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomBoolean() ? -1 : randomIntBetween(1, 10),
+            randomBoolean() ? null : randomAlphaOfLength(6)
         );
     }
 }

+ 148 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+
+public class TextClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<TextClassificationConfigUpdate> {
+
+    public void testFromMap() {
+        TextClassificationConfigUpdate expected = new TextClassificationConfigUpdate(List.of("foo", "bar"), 3, "ml-results");
+        Map<String, Object> config = new HashMap<>(){{
+            put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+            put(NlpConfig.CLASSIFICATION_LABELS.getPreferredName(), List.of("foo", "bar"));
+            put(NlpConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
+        }};
+        assertThat(TextClassificationConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> TextClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+    public void testIsNoop() {
+        assertTrue(new TextClassificationConfigUpdate.Builder().build().isNoop(TextClassificationConfigTests.createRandom()));
+
+        assertFalse(new TextClassificationConfigUpdate.Builder()
+            .setResultsField("foo")
+            .build()
+            .isNoop(new TextClassificationConfig.Builder()
+                .setClassificationLabels(List.of("a", "b"))
+                .setNumTopClasses(-1)
+                .setResultsField("bar").build()));
+
+        assertTrue(new TextClassificationConfigUpdate.Builder()
+            .setNumTopClasses(3)
+            .build()
+            .isNoop(new TextClassificationConfig.Builder().setClassificationLabels(List.of("a", "b")).setNumTopClasses(3).build()));
+        assertFalse(new TextClassificationConfigUpdate.Builder()
+            .setClassificationLabels(List.of("a", "b"))
+            .build()
+            .isNoop(new TextClassificationConfig.Builder().setClassificationLabels(List.of("c", "d")).setNumTopClasses(3).build()));
+    }
+
+    public void testApply() {
+        TextClassificationConfig originalConfig = new TextClassificationConfig(
+            VocabularyConfigTests.createRandom(),
+            BertTokenizationTests.createRandom(),
+            List.of("one", "two"),
+            randomIntBetween(-1, 10),
+            "foo-results"
+        );
+
+        assertThat(originalConfig, equalTo(new TextClassificationConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(new TextClassificationConfig.Builder(originalConfig)
+            .setClassificationLabels(List.of("foo", "bar"))
+            .build(),
+            equalTo(new TextClassificationConfigUpdate.Builder()
+                .setClassificationLabels(List.of("foo", "bar"))
+                .build()
+                .apply(originalConfig)));
+        assertThat(new TextClassificationConfig.Builder(originalConfig)
+                .setResultsField("ml-results")
+                .build(),
+            equalTo(new TextClassificationConfigUpdate.Builder()
+                .setResultsField("ml-results")
+                .build()
+                .apply(originalConfig)
+            ));
+        assertThat(new TextClassificationConfig.Builder(originalConfig)
+                .setNumTopClasses(originalConfig.getNumTopClasses() +1)
+                .build(),
+            equalTo(new TextClassificationConfigUpdate.Builder()
+                .setNumTopClasses(originalConfig.getNumTopClasses() +1)
+                .build()
+                .apply(originalConfig)
+            ));
+    }
+
+    public void testApplyWithInvalidLabels() {
+        TextClassificationConfig originalConfig = TextClassificationConfigTests.createRandom();
+
+        int numberNewLabels = originalConfig.getClassificationLabels().size() +1;
+        List<String> newLabels = randomList(numberNewLabels, numberNewLabels, () -> randomAlphaOfLength(6));
+
+        var update = new TextClassificationConfigUpdate.Builder()
+            .setClassificationLabels(newLabels)
+            .build();
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> update.apply(originalConfig));
+        assertThat(e.getMessage(),
+            containsString("The number of [classification_labels] the model is defined with ["
+                + originalConfig.getClassificationLabels().size() +
+                "] does not match the number in the update [" + numberNewLabels + "]"));
+    }
+
+    @Override
+    protected TextClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return TextClassificationConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<TextClassificationConfigUpdate> instanceReader() {
+        return TextClassificationConfigUpdate::new;
+    }
+
+    @Override
+    protected TextClassificationConfigUpdate createTestInstance() {
+        TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder();
+        if (randomBoolean()) {
+            builder.setNumTopClasses(randomIntBetween(1, 4));
+        }
+        if (randomBoolean()) {
+            builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4)));
+        }
+        if (randomBoolean()) {
+            builder.setResultsField(randomAlphaOfLength(8));
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) {
+        return instance;
+    }
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigTests.java

@@ -50,7 +50,8 @@ public class TextEmbeddingConfigTests extends InferenceConfigItemTestCase<TextEm
     public static TextEmbeddingConfig createRandom() {
         return new TextEmbeddingConfig(
             randomBoolean() ? null : VocabularyConfigTests.createRandom(),
-            randomBoolean() ? null : BertTokenizationTests.createRandom()
+            randomBoolean() ? null : BertTokenizationTests.createRandom(),
+            randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }
 }

+ 80 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class TextEmbeddingConfigUpdateTests extends AbstractBWCSerializationTestCase<TextEmbeddingConfigUpdate> {
+
+    public void testFromMap() {
+        TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate("ml-results");
+        Map<String, Object> config = new HashMap<>(){{
+            put(NlpConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
+        }};
+        assertThat(TextEmbeddingConfigUpdate.fromMap(config), equalTo(expected));
+    }
+
+    public void testFromMapWithUnknownField() {
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> TextEmbeddingConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
+        assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
+    }
+
+
+    public void testApply() {
+        TextEmbeddingConfig originalConfig = TextEmbeddingConfigTests.createRandom();
+
+        assertThat(originalConfig, sameInstance(new TextEmbeddingConfigUpdate.Builder().build().apply(originalConfig)));
+
+        assertThat(new TextEmbeddingConfig(
+            originalConfig.getVocabularyConfig(),
+            originalConfig.getTokenization(),
+            "ml-results"),
+            equalTo(new TextEmbeddingConfigUpdate.Builder()
+                .setResultsField("ml-results")
+                .build()
+                .apply(originalConfig)
+            ));
+    }
+
+    @Override
+    protected TextEmbeddingConfigUpdate doParseInstance(XContentParser parser) throws IOException {
+        return TextEmbeddingConfigUpdate.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<TextEmbeddingConfigUpdate> instanceReader() {
+        return TextEmbeddingConfigUpdate::new;
+    }
+
+    @Override
+    protected TextEmbeddingConfigUpdate createTestInstance() {
+        TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder();
+        if (randomBoolean()) {
+            builder.setResultsField(randomAlphaOfLength(8));
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) {
+        return instance;
+    }
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigTests.java

@@ -55,7 +55,8 @@ public class ZeroShotClassificationConfigTests extends InferenceConfigItemTestCa
             randomBoolean() ? null : BertTokenizationTests.createRandom(),
             randomAlphaOfLength(10),
             randomBoolean(),
-            randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10))
+            randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomAlphaOfLength(7)
         );
     }
 }

+ 27 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java

@@ -50,10 +50,11 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
     }
 
     public void testFromMap() {
-        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(List.of("foo", "bar"), false);
+        ZeroShotClassificationConfigUpdate expected = new ZeroShotClassificationConfigUpdate(List.of("foo", "bar"), false, "ml-results");
         Map<String, Object> config = new HashMap<>(){{
             put(ZeroShotClassificationConfig.LABELS.getPreferredName(), List.of("foo", "bar"));
             put(ZeroShotClassificationConfig.MULTI_LABEL.getPreferredName(), false);
+            put(ZeroShotClassificationConfig.RESULTS_FIELD.getPreferredName(), "ml-results");
         }};
         assertThat(ZeroShotClassificationConfigUpdate.fromMap(config), equalTo(expected));
     }
@@ -71,7 +72,8 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
             randomBoolean() ? null : BertTokenizationTests.createRandom(),
             randomAlphaOfLength(10),
             randomBoolean(),
-            randomList(1, 5, () -> randomAlphaOfLength(10))
+            randomList(1, 5, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null : randomAlphaOfLength(8)
         );
 
         assertThat(originalConfig, equalTo(new ZeroShotClassificationConfigUpdate.Builder().build().apply(originalConfig)));
@@ -83,7 +85,8 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
                 originalConfig.getTokenization(),
                 originalConfig.getHypothesisTemplate(),
                 originalConfig.isMultiLabel(),
-                List.of("foo", "bar")
+                List.of("foo", "bar"),
+                originalConfig.getResultsField()
             ),
             equalTo(
                 new ZeroShotClassificationConfigUpdate.Builder()
@@ -98,7 +101,8 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
                 originalConfig.getTokenization(),
                 originalConfig.getHypothesisTemplate(),
                 true,
-                originalConfig.getLabels()
+                originalConfig.getLabels(),
+                originalConfig.getResultsField()
             ),
             equalTo(
                 new ZeroShotClassificationConfigUpdate.Builder()
@@ -106,6 +110,22 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
                     .apply(originalConfig)
             )
         );
+        assertThat(
+            new ZeroShotClassificationConfig(
+                originalConfig.getClassificationLabels(),
+                originalConfig.getVocabularyConfig(),
+                originalConfig.getTokenization(),
+                originalConfig.getHypothesisTemplate(),
+                originalConfig.isMultiLabel(),
+                originalConfig.getLabels(),
+                "updated-field"
+            ),
+            equalTo(
+                new ZeroShotClassificationConfigUpdate.Builder()
+                    .setResultsField("updated-field").build()
+                    .apply(originalConfig)
+            )
+        );
     }
 
     public void testApplyWithEmptyLabelsInConfigAndUpdate() {
@@ -115,6 +135,7 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
             randomBoolean() ? null : BertTokenizationTests.createRandom(),
             randomAlphaOfLength(10),
             randomBoolean(),
+            null,
             null
         );
 
@@ -128,7 +149,8 @@ public class ZeroShotClassificationConfigUpdateTests extends InferenceConfigItem
     public static ZeroShotClassificationConfigUpdate createRandom() {
         return new ZeroShotClassificationConfigUpdate(
             randomBoolean() ? null : randomList(1,5, () -> randomAlphaOfLength(10)),
-            randomBoolean() ? null : randomBoolean()
+            randomBoolean() ? null : randomBoolean(),
+            randomBoolean() ? null : randomAlphaOfLength(5)
         );
     }
 }

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java

@@ -276,7 +276,7 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
             new PutTrainedModelAction.Request(
                 TrainedModelConfig.builder()
                     .setModelType(TrainedModelType.PYTORCH)
-                    .setInferenceConfig(new PassThroughConfig(null, new BertTokenization(null, false, null)))
+                    .setInferenceConfig(new PassThroughConfig(null, new BertTokenization(null, false, null), null))
                     .setModelId(modelId)
                     .build(),
                 false

+ 2 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

@@ -200,7 +200,8 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
                     .setInferenceConfig(
                         new PassThroughConfig(
                             null,
-                            new BertTokenization(null, false, null)
+                            new BertTokenization(null, false, null),
+                            null
                         )
                     )
                     .setModelId(TRAINED_MODEL_ID)

+ 2 - 1
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java

@@ -72,7 +72,8 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase {
                             new VocabularyConfig(
                                 InferenceIndexConstants.nativeDefinitionStore()
                             ),
-                            new BertTokenization(null, false, null)
+                            new BertTokenization(null, false, null),
+                            null
                         )
                     )
                     .setModelId(modelId)

+ 27 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -33,10 +33,20 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@@ -357,13 +367,29 @@ public class InferenceProcessor extends AbstractProcessor {
             if (configMap.containsKey(ClassificationConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
                 return ClassificationConfigUpdate.fromMap(valueMap);
+            } else if (configMap.containsKey(FillMaskConfig.NAME)) {
+                checkSupportedVersion(new FillMaskConfig(null, null, null, null));
+                return FillMaskConfigUpdate.fromMap(valueMap);
+            } else if (configMap.containsKey(NerConfig.NAME)) {
+                checkSupportedVersion(new NerConfig(null, null, null, null));
+                return NerConfigUpdate.fromMap(valueMap);
+            } else if (configMap.containsKey(PassThroughConfig.NAME)) {
+                checkSupportedVersion(new PassThroughConfig(null, null, null));
+                return PassThroughConfigUpdate.fromMap(valueMap);
             } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                 return RegressionConfigUpdate.fromMap(valueMap);
+            } else if (configMap.containsKey(TextClassificationConfig.NAME)) {
+                checkSupportedVersion(new TextClassificationConfig(null, null, List.of("meeting", "requirements"), null, null));
+                return TextClassificationConfigUpdate.fromMap(valueMap);
+            } else if (configMap.containsKey(TextEmbeddingConfig.NAME)) {
+                checkSupportedVersion(new TextEmbeddingConfig(null, null, null));
+                return TextEmbeddingConfigUpdate.fromMap(valueMap);
             } else if (configMap.containsKey(ZeroShotClassificationConfig.NAME)) {
-                checkSupportedVersion(new ZeroShotClassificationConfig(List.of("unused"), null, null, null, null, null));
+                checkSupportedVersion(new ZeroShotClassificationConfig(List.of("unused"), null, null, null, null, null, null));
                 return ZeroShotClassificationConfigUpdate.fromMap(valueMap);
             }
+            // TODO missing update types
             else {
                 throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
                     configMap.keySet(),

+ 8 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -23,8 +23,6 @@ import java.util.List;
 
 public class FillMaskProcessor implements NlpTask.Processor {
 
-    private static final int NUM_RESULTS = 5;
-
     private final NlpTask.RequestBuilder requestBuilder;
 
     FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
@@ -57,11 +55,14 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
-        return this::processResult;
+        if (config instanceof FillMaskConfig) {
+            return (tokenization, result) -> processResult(tokenization, result, ((FillMaskConfig)config).getNumTopClasses());
+        } else {
+            return (tokenization, result) -> processResult(tokenization, result, FillMaskConfig.DEFAULT_NUM_RESULTS);
+        }
     }
 
-    InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
-
+    InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult, int numResults) {
         if (tokenization.getTokenizations().isEmpty() ||
             tokenization.getTokenizations().get(0).getTokens().length == 0) {
             return new FillMaskResults(Collections.emptyList());
@@ -71,8 +72,8 @@ public class FillMaskProcessor implements NlpTask.Processor {
         // TODO - process all results in the batch
         double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
 
-        NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(NUM_RESULTS, normalizedScores);
-        List<FillMaskResults.Prediction> results = new ArrayList<>(NUM_RESULTS);
+        NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(numResults, normalizedScores);
+        List<FillMaskResults.Prediction> results = new ArrayList<>(numResults);
         for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) {
             String predictedToken = tokenization.getFromVocab(scoreAndIndex.index);
             String sequence = tokenization.getTokenizations().get(0).getInput().replace(BertTokenizer.MASK_TOKEN, predictedToken);

+ 2 - 28
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -7,8 +7,6 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -21,7 +19,6 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.Comparator;
 import java.util.List;
-import java.util.Locale;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -34,11 +31,7 @@ public class TextClassificationProcessor implements NlpTask.Processor {
     TextClassificationProcessor(NlpTokenizer tokenizer, TextClassificationConfig config) {
         this.requestBuilder = tokenizer.requestBuilder();
         List<String> classLabels = config.getClassificationLabels();
-        if (classLabels == null || classLabels.isEmpty()) {
-            this.classLabels = new String[] {"negative", "positive"};
-        } else {
-            this.classLabels = classLabels.toArray(String[]::new);
-        }
+        this.classLabels = classLabels.toArray(String[]::new);
         // negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size
         // This is a nice way of setting the value
         this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
@@ -46,26 +39,7 @@ public class TextClassificationProcessor implements NlpTask.Processor {
     }
 
     private void validate() {
-        if (classLabels.length < 2) {
-            throw new ValidationException().addValidationError(
-                String.format(
-                    Locale.ROOT,
-                    "Text classification requires at least 2 [%s]. Invalid labels [%s]",
-                    TextClassificationConfig.CLASSIFICATION_LABELS,
-                    Strings.arrayToCommaDelimitedString(classLabels)
-                )
-            );
-        }
-        if (numTopClasses == 0) {
-            throw new ValidationException().addValidationError(
-                String.format(
-                    Locale.ROOT,
-                    "Text classification requires at least 1 [%s]; provided [%d]",
-                    TextClassificationConfig.NUM_TOP_CLASSES,
-                    numTopClasses
-                )
-            );
-        }
+        // validation occurs in TextClassificationConfig
     }
 
     @Override

+ 9 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -24,7 +24,8 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Mockito.mock;
 
-public class FillMaskProcessorTests extends ESTestCase {
+public class
+FillMaskProcessorTests extends ESTestCase {
 
     public void testProcessResults() {
         // only the scores of the MASK index array
@@ -48,11 +49,11 @@ public class FillMaskProcessorTests extends ESTestCase {
         TokenizationResult tokenization = new TokenizationResult(vocab);
         tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
 
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
 
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
-        FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null));
-        assertThat(result.getPredictions(), hasSize(5));
+        FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 0L, null), 4);
+        assertThat(result.getPredictions(), hasSize(4));
         FillMaskResults.Prediction prediction = result.getPredictions().get(0);
         assertEquals("France", prediction.getToken());
         assertEquals("The capital of France is Paris", prediction.getSequence());
@@ -70,10 +71,10 @@ public class FillMaskProcessorTests extends ESTestCase {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
         tokenization.addTokenization("", new String[]{}, new int[] {}, new int[] {});
 
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
         PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][]{{{}}}, 0L, null);
-        FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult);
+        FillMaskResults result = (FillMaskResults) processor.processResult(tokenization, pyTorchResult, 5);
 
         assertThat(result.getPredictions(), empty());
     }
@@ -81,7 +82,7 @@ public class FillMaskProcessorTests extends ESTestCase {
     public void testValidate_GivenMissingMaskToken() {
         List<String> input = List.of("The capital of France is Paris");
 
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
 
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
@@ -93,7 +94,7 @@ public class FillMaskProcessorTests extends ESTestCase {
     public void testProcessResults_GivenMultipleMaskTokens() {
         List<String> input = List.of("The capital of [MASK] is [MASK]");
 
-        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null);
+        FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
         FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
 
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class,

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -64,7 +64,7 @@ public class NerProcessorTests extends ESTestCase {
         };
 
         List<String> classLabels = Arrays.stream(tags).map(NerProcessor.IobTag::toString).collect(Collectors.toList());
-        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
+        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
 
         ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
         assertThat(ve.getMessage(),
@@ -73,7 +73,7 @@ public class NerProcessorTests extends ESTestCase {
 
     public void testValidate_NotAEntityLabel() {
         List<String> classLabels = List.of("foo", NerProcessor.IobTag.B_MISC.toString());
-        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels);
+        NerConfig nerConfig = new NerConfig(new VocabularyConfig("test-index"), null, classLabels, null);
 
         ValidationException ve = expectThrows(ValidationException.class, () -> new NerProcessor(mock(BertTokenizer.class), nerConfig));
         assertThat(ve.getMessage(), containsString("classification label [foo] is not an entity I-O-B tag"));

+ 6 - 32
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
@@ -25,7 +24,6 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
-import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.mock;
@@ -33,7 +31,9 @@ import static org.mockito.Mockito.mock;
 public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
-        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
+        TextClassificationConfig config = new TextClassificationConfig(
+            new VocabularyConfig("test-index"), null, List.of("a", "b"), null, null);
+
         TextClassificationProcessor processor = new TextClassificationProcessor(mock(BertTokenizer.class), config);
         {
             PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
@@ -62,7 +62,9 @@ public class TextClassificationProcessorTests extends ESTestCase {
             ),
             new BertTokenization(null, null, 512));
 
-        TextClassificationConfig config = new TextClassificationConfig(new VocabularyConfig("test-index"), null, null, null);
+        TextClassificationConfig config = new TextClassificationConfig(
+            new VocabularyConfig("test-index"), null, List.of("a", "b"), null, null);
+
         TextClassificationProcessor processor = new TextClassificationProcessor(tokenizer, config);
 
         NlpTask.Request request = processor.getRequestBuilder(config).buildRequest(List.of("Elasticsearch fun"), "request1");
@@ -74,32 +76,4 @@ public class TextClassificationProcessorTests extends ESTestCase {
         assertEquals(Arrays.asList(3, 0, 1, 2, 4), ((List<List<Integer>>)jsonDocAsMap.get("tokens")).get(0));
         assertEquals(Arrays.asList(1, 1, 1, 1, 1), ((List<List<Integer>>)jsonDocAsMap.get("arg_1")).get(0));
     }
-
-    public void testValidate() {
-        ValidationException validationException = expectThrows(
-            ValidationException.class,
-            () -> new TextClassificationProcessor(
-                mock(BertTokenizer.class),
-                new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("too few"), null)
-            )
-        );
-
-        assertThat(
-            validationException.getMessage(),
-            containsString("Text classification requires at least 2 [classification_labels]. Invalid labels [too few]")
-        );
-
-        validationException = expectThrows(
-            ValidationException.class,
-            () -> new TextClassificationProcessor(
-                mock(BertTokenizer.class),
-                new TextClassificationConfig(new VocabularyConfig("test-index"), null, List.of("class", "labels"), 0)
-            )
-        );
-
-        assertThat(
-            validationException.getMessage(),
-            containsString("Text classification requires at least 1 [num_top_classes]; provided [0]")
-        );
-    }
 }

+ 1 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessorTests.java

@@ -43,6 +43,7 @@ public class ZeroShotClassificationProcessorTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         ZeroShotClassificationProcessor processor = new ZeroShotClassificationProcessor(tokenizer, config);