
[ML] update truncation default & adding field output when input is truncated (#79942)

This commit makes the following two changes (along with some refactoring):
- NLP results now indicate whether the input was truncated.
- The default truncation is now `none` instead of `first`.
Benjamin Trent 4 years ago
parent
commit
375fc779b4
25 changed files with 429 additions and 113 deletions
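
To make the first change concrete, here is a minimal sketch (not part of the commit itself) of how the new flag surfaces through the result classes below. It assumes the `NerResults` constructor and the `NlpInferenceResults.asMap()` behaviour shown in this diff; the field name and text are illustrative only.

    import java.util.List;
    import java.util.Map;

    import org.elasticsearch.xpack.core.ml.inference.results.NerResults;

    public class TruncationFlagSketch {
        public static void main(String[] args) {
            // Input that had to be truncated: the result map (and the XContent output) carries is_truncated=true.
            NerResults truncated = new NerResults("predicted_value", "annotated text", List.of(), true);
            Map<String, Object> asMap = truncated.asMap();
            System.out.println(asMap.get("is_truncated")); // true

            // Untruncated input: the key is omitted entirely, so existing responses are unchanged.
            NerResults whole = new NerResults("predicted_value", "annotated text", List.of(), false);
            System.out.println(whole.asMap().containsKey("is_truncated")); // false
        }
    }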
  1. +1 -1    docs/reference/ml/ml-shared.asciidoc
  2. +8 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  3. +11 -26  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java
  4. +12 -12  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java
  5. +129 -0  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java
  6. +74 -0   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java
  7. +9 -10   x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java
  8. +11 -11  x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java
  9. +1 -1    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java
  10. +9 -3   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java
  11. +10 -1  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java
  12. +81 -0  x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java
  13. +7 -2   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java
  14. +7 -2   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java
  15. +16 -1  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  16. +2 -4   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  17. +6 -1   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  18. +2 -1   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  19. +3 -10  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  20. +2 -1   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
  21. +3 -10  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  22. +7 -10  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java
  23. +13 -3  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java
  24. +3 -0   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java
  25. +2 -2   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

+ 1 - 1
docs/reference/ml/ml-shared.asciidoc

@@ -927,7 +927,7 @@ end::inference-config-nlp-tokenization-bert-do-lower-case[]
 
 tag::inference-config-nlp-tokenization-bert-truncate[]
 Indicates how tokens are truncated when they exceed `max_sequence_length`.
-The default value is `first`.
+The default value is `none`.
 +
 --
 * `none`: No truncation occurs; the inference request receives an error.
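
The documented contract is easiest to see side by side. The following is a hedged, self-contained sketch of the two settings; `applyTruncation` is a hypothetical helper written for illustration (not the `BertTokenizer` implementation), and the enum mirrors the `NONE`/`FIRST` values used by `Tokenization.Truncate` in this commit.

    import java.util.List;

    public class TruncateSettingSketch {

        enum Truncate { NONE, FIRST }

        // Hypothetical helper illustrating the documented behaviour, not the real tokenizer code.
        static List<String> applyTruncation(List<String> tokens, int maxSequenceLength, Truncate truncate) {
            if (tokens.size() <= maxSequenceLength) {
                return tokens;
            }
            if (truncate == Truncate.NONE) {
                // New default: an over-long input fails the inference request instead of being silently cut.
                throw new IllegalArgumentException(
                    "Input too large: [" + tokens.size() + "] tokens exceed max_sequence_length [" + maxSequenceLength + "]"
                );
            }
            // FIRST (the previous default): keep the leading tokens; downstream results are then flagged as truncated.
            return tokens.subList(0, maxSequenceLength);
        }
    }

With `first`, the caller previously had no signal that input was dropped; after this commit truncation must be requested explicitly and, when it happens, the result carries `is_truncated`.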

+ 8 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
+import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
@@ -498,7 +499,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
         );
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
-
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                InferenceResults.class,
+                NlpClassificationInferenceResults.NAME,
+                NlpClassificationInferenceResults::new
+            )
+        );
         // Inference Configs
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)

+ 11 - 26
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java

@@ -10,41 +10,27 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 
 import java.io.IOException;
-import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
-public class FillMaskResults extends ClassificationInferenceResults {
+public class FillMaskResults extends NlpClassificationInferenceResults {
 
     public static final String NAME = "fill_mask_result";
 
     private final String predictedSequence;
 
     public FillMaskResults(
-        double value,
         String classificationLabel,
         String predictedSequence,
         List<TopClassEntry> topClasses,
-        String topNumClassesField,
        String resultsField,
-        Double predictionProbability
+        Double predictionProbability,
+        boolean isTruncated
     ) {
-        super(
-            value,
-            classificationLabel,
-            topClasses,
-            List.of(),
-            topNumClassesField,
-            resultsField,
-            PredictionFieldType.STRING,
-            0,
-            predictionProbability,
-            null
-        );
+        super(classificationLabel, topClasses, resultsField, predictionProbability, isTruncated);
         this.predictedSequence = predictedSequence;
     }
 
@@ -54,8 +40,8 @@ public class FillMaskResults extends ClassificationInferenceResults {
     }
 
     @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
+    public void doWriteTo(StreamOutput out) throws IOException {
+        super.doWriteTo(out);
         out.writeString(predictedSequence);
     }
 
@@ -64,11 +50,9 @@ public class FillMaskResults extends ClassificationInferenceResults {
     }
 
     @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> map = new LinkedHashMap<>();
+    void addMapFields(Map<String, Object> map) {
+        super.addMapFields(map);
         map.put(resultsField + "_sequence", predictedSequence);
-        map.putAll(super.asMap());
-        return map;
     }
 
     @Override
@@ -77,8 +61,9 @@ public class FillMaskResults extends ClassificationInferenceResults {
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence);
+    public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        super.doXContentBody(builder, params);
+        builder.field(resultsField + "_sequence", predictedSequence);
     }
 
     @Override

+ 12 - 12
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java

@@ -20,7 +20,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
 
-public class NerResults implements InferenceResults {
+public class NerResults extends NlpInferenceResults {
 
     public static final String NAME = "ner_result";
     public static final String ENTITY_FIELD = "entities";
@@ -30,27 +30,28 @@
 
     private final List<EntityGroup> entityGroups;
 
-    public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups) {
+    public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups, boolean isTruncated) {
+        super(isTruncated);
         this.entityGroups = Objects.requireNonNull(entityGroups);
         this.resultsField = Objects.requireNonNull(resultsField);
         this.annotatedResult = Objects.requireNonNull(annotatedResult);
     }
 
     public NerResults(StreamInput in) throws IOException {
+        super(in);
         entityGroups = in.readList(EntityGroup::new);
         resultsField = in.readString();
         annotatedResult = in.readString();
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(resultsField, annotatedResult);
         builder.startArray("entities");
         for (EntityGroup entity : entityGroups) {
             entity.toXContent(builder, params);
         }
         builder.endArray();
-        return builder;
     }
 
     @Override
@@ -59,18 +60,16 @@
     }
 
     @Override
-    public void writeTo(StreamOutput out) throws IOException {
+    void doWriteTo(StreamOutput out) throws IOException {
         out.writeList(entityGroups);
         out.writeString(resultsField);
         out.writeString(annotatedResult);
     }
 
     @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> map = new LinkedHashMap<>();
+    void addMapFields(Map<String, Object> map) {
         map.put(resultsField, annotatedResult);
         map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
-        return map;
     }
 
     @Override
@@ -95,15 +94,16 @@
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
         NerResults that = (NerResults) o;
-        return Objects.equals(entityGroups, that.entityGroups)
-            && Objects.equals(resultsField, that.resultsField)
-            && Objects.equals(annotatedResult, that.annotatedResult);
+        return Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(annotatedResult, that.annotatedResult)
+            && Objects.equals(entityGroups, that.entityGroups);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(entityGroups, resultsField, annotatedResult);
+        return Objects.hash(super.hashCode(), resultsField, annotatedResult, entityGroups);
     }
 
     public static class EntityGroup implements ToXContentObject, Writeable {

+ 129 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java

@@ -0,0 +1,129 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+public class NlpClassificationInferenceResults extends NlpInferenceResults {
+
+    public static final String NAME = "nlp_classification";
+
+    // Accessed in sub-classes
+    protected final String resultsField;
+    private final String classificationLabel;
+    private final Double predictionProbability;
+    private final List<TopClassEntry> topClasses;
+
+    public NlpClassificationInferenceResults(
+        String classificationLabel,
+        List<TopClassEntry> topClasses,
+        String resultsField,
+        Double predictionProbability,
+        boolean isTruncated
+    ) {
+        super(isTruncated);
+        this.classificationLabel = Objects.requireNonNull(classificationLabel);
+        this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
+        this.resultsField = resultsField;
+        this.predictionProbability = predictionProbability;
+    }
+
+    public NlpClassificationInferenceResults(StreamInput in) throws IOException {
+        super(in);
+        this.classificationLabel = in.readString();
+        this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
+        this.resultsField = in.readString();
+        this.predictionProbability = in.readOptionalDouble();
+    }
+
+    public String getClassificationLabel() {
+        return classificationLabel;
+    }
+
+    public List<TopClassEntry> getTopClasses() {
+        return topClasses;
+    }
+
+    @Override
+    public void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(classificationLabel);
+        out.writeCollection(topClasses);
+        out.writeString(resultsField);
+        out.writeOptionalDouble(predictionProbability);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        NlpClassificationInferenceResults that = (NlpClassificationInferenceResults) o;
+        return Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(classificationLabel, that.classificationLabel)
+            && Objects.equals(predictionProbability, that.predictionProbability)
+            && Objects.equals(topClasses, that.topClasses);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), resultsField, classificationLabel, predictionProbability, topClasses);
+    }
+
+    public Double getPredictionProbability() {
+        return predictionProbability;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public Object predictedValue() {
+        return classificationLabel;
+    }
+
+    @Override
+    void addMapFields(Map<String, Object> map) {
+        map.put(resultsField, classificationLabel);
+        if (topClasses.isEmpty() == false) {
+            map.put(
+                NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD,
+                topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())
+            );
+        }
+        if (predictionProbability != null) {
+            map.put(PREDICTION_PROBABILITY, predictionProbability);
+        }
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        builder.field(resultsField, classificationLabel);
+        if (topClasses.size() > 0) {
+            builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
+        }
+        if (predictionProbability != null) {
+            builder.field(PREDICTION_PROBABILITY, predictionProbability);
+        }
+    }
+}

+ 74 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+abstract class NlpInferenceResults implements InferenceResults {
+
+    protected final boolean isTruncated;
+
+    NlpInferenceResults(boolean isTruncated) {
+        this.isTruncated = isTruncated;
+    }
+
+    NlpInferenceResults(StreamInput in) throws IOException {
+        this.isTruncated = in.readBoolean();
+    }
+
+    abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException;
+
+    abstract void doWriteTo(StreamOutput out) throws IOException;
+
+    abstract void addMapFields(Map<String, Object> map);
+
+    @Override
+    public final void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(isTruncated);
+        doWriteTo(out);
+    }
+
+    @Override
+    public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        doXContentBody(builder, params);
+        if (isTruncated) {
+            builder.field("is_truncated", isTruncated);
+        }
+        return builder;
+    }
+
+    @Override
+    public final Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        addMapFields(map);
+        if (isTruncated) {
+            map.put("is_truncated", isTruncated);
+        }
+        return map;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        NlpInferenceResults that = (NlpInferenceResults) o;
+        return isTruncated == that.isTruncated;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(isTruncated);
+    }
+}

+ 9 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java

@@ -13,23 +13,24 @@ import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Objects;
 
-public class PyTorchPassThroughResults implements InferenceResults {
+public class PyTorchPassThroughResults extends NlpInferenceResults {
 
     public static final String NAME = "pass_through_result";
 
     private final double[][] inference;
     private final String resultsField;
 
-    public PyTorchPassThroughResults(String resultsField, double[][] inference) {
+    public PyTorchPassThroughResults(String resultsField, double[][] inference, boolean isTruncated) {
+        super(isTruncated);
         this.inference = inference;
         this.resultsField = resultsField;
     }
 
     public PyTorchPassThroughResults(StreamInput in) throws IOException {
+        super(in);
         inference = in.readArray(StreamInput::readDoubleArray, double[][]::new);
         resultsField = in.readString();
     }
@@ -39,9 +40,8 @@ public class PyTorchPassThroughResults implements InferenceResults {
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(resultsField, inference);
-        return builder;
     }
 
     @Override
@@ -50,7 +50,7 @@
     }
 
     @Override
-    public void writeTo(StreamOutput out) throws IOException {
+    public void doWriteTo(StreamOutput out) throws IOException {
         out.writeArray(StreamOutput::writeDoubleArray, inference);
         out.writeString(resultsField);
     }
@@ -61,10 +61,8 @@
     }
 
     @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> map = new LinkedHashMap<>();
+    void addMapFields(Map<String, Object> map) {
         map.put(resultsField, inference);
-        return map;
     }
 
     @Override
@@ -76,12 +74,13 @@
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
         PyTorchPassThroughResults that = (PyTorchPassThroughResults) o;
         return Arrays.deepEquals(inference, that.inference) && Objects.equals(resultsField, that.resultsField);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(Arrays.deepHashCode(inference), resultsField);
+        return Objects.hash(super.hashCode(), resultsField, Arrays.deepHashCode(inference));
     }
 }

+ 11 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java

@@ -13,23 +13,24 @@ import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Objects;
 
-public class TextEmbeddingResults implements InferenceResults {
+public class TextEmbeddingResults extends NlpInferenceResults {
 
     public static final String NAME = "text_embedding_result";
 
     private final String resultsField;
     private final double[] inference;
 
-    public TextEmbeddingResults(String resultsField, double[] inference) {
+    public TextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) {
+        super(isTruncated);
         this.inference = inference;
         this.resultsField = resultsField;
     }
 
     public TextEmbeddingResults(StreamInput in) throws IOException {
+        super(in);
         inference = in.readDoubleArray();
         resultsField = in.readString();
     }
@@ -43,8 +44,8 @@
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return builder.field(resultsField, inference);
+    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        builder.field(resultsField, inference);
     }
 
     @Override
@@ -53,16 +54,14 @@
     }
 
     @Override
-    public void writeTo(StreamOutput out) throws IOException {
+    void doWriteTo(StreamOutput out) throws IOException {
         out.writeDoubleArray(inference);
         out.writeString(resultsField);
     }
 
     @Override
-    public Map<String, Object> asMap() {
-        Map<String, Object> map = new LinkedHashMap<>();
+    void addMapFields(Map<String, Object> map) {
         map.put(resultsField, inference);
-        return map;
     }
 
     @Override
@@ -74,12 +73,13 @@
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
         TextEmbeddingResults that = (TextEmbeddingResults) o;
-        return Arrays.equals(inference, that.inference) && Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(Arrays.hashCode(inference), resultsField);
+        return Objects.hash(super.hashCode(), resultsField, Arrays.hashCode(inference));
     }
 }

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java

@@ -47,7 +47,7 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl
     private static final int DEFAULT_MAX_SEQUENCE_LENGTH = 512;
     private static final boolean DEFAULT_DO_LOWER_CASE = false;
     private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true;
-    private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST;
+    private static final Truncate DEFAULT_TRUNCATION = Truncate.NONE;
 
     static <T extends Tokenization> void declareCommonFields(ConstructingObjectParser<T, ?> parser) {
         parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE);

+ 9 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java

@@ -18,8 +18,10 @@ import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
 
 public class FillMaskResultsTests extends AbstractWireSerializingTestCase<FillMaskResults> {
@@ -36,13 +38,12 @@ public class FillMaskResultsTests extends AbstractWireSerializingTestCase<FillMa
             resultList.add(TopClassEntryTests.createRandomTopClassEntry());
         }
         return new FillMaskResults(
-            0.0,
             randomAlphaOfLength(10),
             randomAlphaOfLength(10),
             resultList,
-            DEFAULT_TOP_CLASSES_RESULTS_FIELD,
             DEFAULT_RESULTS_FIELD,
-            randomDouble()
+            randomDouble(),
+            randomBoolean()
         );
     }
 
@@ -54,6 +55,11 @@ public class FillMaskResultsTests extends AbstractWireSerializingTestCase<FillMa
         assertThat(asMap.get(PREDICTION_PROBABILITY), equalTo(testInstance.getPredictionProbability()));
         assertThat(asMap.get(DEFAULT_RESULTS_FIELD + "_sequence"), equalTo(testInstance.getPredictedSequence()));
         List<Map<String, Object>> resultList = (List<Map<String, Object>>) asMap.get(DEFAULT_TOP_CLASSES_RESULTS_FIELD);
+        if (testInstance.isTruncated) {
+            assertThat(asMap.get("is_truncated"), is(true));
+        } else {
+            assertThat(asMap, not(hasKey("is_truncated")));
+        }
         if (testInstance.getTopClasses().size() == 0) {
             assertThat(resultList, is(nullValue()));
         } else {

+ 10 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java

@@ -17,7 +17,10 @@ import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ml.inference.results.NerResults.ENTITY_FIELD;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 
 public class NerResultsTests extends InferenceResultsTestCase<NerResults> {
     @Override
@@ -40,7 +43,8 @@ public class NerResultsTests extends InferenceResultsTestCase<NerResults> {
                     randomIntBetween(-1, 5),
                     randomIntBetween(5, 10)
                 )
-            ).limit(numEntities).collect(Collectors.toList())
+            ).limit(numEntities).collect(Collectors.toList()),
+            randomBoolean()
         );
     }
 
@@ -54,6 +58,11 @@ public class NerResultsTests extends InferenceResultsTestCase<NerResults> {
         }
         assertThat(resultList, hasSize(testInstance.getEntityGroups().size()));
         assertThat(asMap.get(testInstance.getResultsField()), equalTo(testInstance.getAnnotatedResult()));
+        if (testInstance.isTruncated) {
+            assertThat(asMap.get("is_truncated"), is(true));
+        } else {
+            assertThat(asMap, not(hasKey("is_truncated")));
+        }
         for (int i = 0; i < testInstance.getEntityGroups().size(); i++) {
             NerResults.EntityGroup entity = testInstance.getEntityGroups().get(i);
             Map<String, Object> map = resultList.get(i);

+ 81 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.ingest.IngestDocument;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
+import static org.hamcrest.Matchers.equalTo;
+
+public class NlpClassificationInferenceResultsTests extends InferenceResultsTestCase<NlpClassificationInferenceResults> {
+
+    public static NlpClassificationInferenceResults createRandomResults() {
+        return new NlpClassificationInferenceResults(
+            randomAlphaOfLength(10),
+            randomBoolean()
+                ? null
+                : Stream.generate(TopClassEntryTests::createRandomTopClassEntry)
+                    .limit(randomIntBetween(0, 10))
+                    .collect(Collectors.toList()),
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false),
+            randomBoolean()
+        );
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testWriteResultsWithTopClasses() {
+        List<TopClassEntry> entries = Arrays.asList(
+            new TopClassEntry("foo", 0.7, 0.7),
+            new TopClassEntry("bar", 0.2, 0.2),
+            new TopClassEntry("baz", 0.1, 0.1)
+        );
+        NlpClassificationInferenceResults result = new NlpClassificationInferenceResults(
+            "foo",
+            entries,
+            "my_results",
+            0.7,
+            randomBoolean()
+        );
+        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
+        writeResult(result, document, "result_field", "test");
+
+        List<?> list = document.getFieldValue("result_field.top_classes", List.class);
+        assertThat(list.size(), equalTo(3));
+
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> map = (Map<String, Object>) list.get(i);
+            assertThat(map, equalTo(entries.get(i).asValueMap()));
+        }
+
+        assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
+    }
+
+    @Override
+    protected NlpClassificationInferenceResults createTestInstance() {
+        return createRandomResults();
+    }
+
+    @Override
+    protected Writeable.Reader<NlpClassificationInferenceResults> instanceReader() {
+        return NlpClassificationInferenceResults::new;
+    }
+
+    @Override
+    void assertFieldValues(NlpClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) {
+        String path = resultsField + "." + createdInstance.getResultsField();
+        assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue()));
+    }
+}

+ 7 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java

@@ -14,6 +14,7 @@ import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
 
 public class PyTorchPassThroughResultsTests extends InferenceResultsTestCase<PyTorchPassThroughResults> {
     @Override
@@ -32,14 +33,18 @@ public class PyTorchPassThroughResultsTests extends InferenceResultsTestCase<PyT
             }
         }
 
-        return new PyTorchPassThroughResults(DEFAULT_RESULTS_FIELD, arr);
+        return new PyTorchPassThroughResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
     }
 
     public void testAsMap() {
         PyTorchPassThroughResults testInstance = createTestInstance();
         Map<String, Object> asMap = testInstance.asMap();
-        assertThat(asMap.keySet(), hasSize(1));
+        int size = testInstance.isTruncated ? 2 : 1;
+        assertThat(asMap.keySet(), hasSize(size));
         assertArrayEquals(testInstance.getInference(), (double[][]) asMap.get(DEFAULT_RESULTS_FIELD));
+        if (testInstance.isTruncated) {
+            assertThat(asMap.get("is_truncated"), is(true));
+        }
     }
 
     @Override

+ 7 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java

@@ -14,6 +14,7 @@ import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
 
 public class TextEmbeddingResultsTests extends InferenceResultsTestCase<TextEmbeddingResults> {
     @Override
@@ -29,14 +30,18 @@ public class TextEmbeddingResultsTests extends InferenceResultsTestCase<TextEmbe
             arr[i] = randomDouble();
         }
 
-        return new TextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr);
+        return new TextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean());
     }
 
     public void testAsMap() {
         TextEmbeddingResults testInstance = createTestInstance();
         Map<String, Object> asMap = testInstance.asMap();
-        assertThat(asMap.keySet(), hasSize(1));
+        int size = testInstance.isTruncated ? 2 : 1;
+        assertThat(asMap.keySet(), hasSize(size));
        assertArrayEquals(testInstance.getInference(), (double[]) asMap.get(DEFAULT_RESULTS_FIELD), 1e-10);
+        if (testInstance.isTruncated) {
+            assertThat(asMap.get("is_truncated"), is(true));
+        }
     }
 
     @Override

+ 16 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -238,7 +238,16 @@ public class DeploymentManager {
         }
 
         final long requestId = requestIdCounter.getAndIncrement();
-        InferenceAction inferenceAction = new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener);
+        InferenceAction inferenceAction = new InferenceAction(
+            task.getModelId(),
+            requestId,
+            timeout,
+            processContext,
+            config,
+            doc,
+            threadPool,
+            listener
+        );
         try {
             processContext.executorService.execute(inferenceAction);
         } catch (Exception e) {
@@ -247,6 +256,7 @@
     }
 
     static class InferenceAction extends AbstractRunnable {
+        private final String modelId;
         private final long requestId;
         private final TimeValue timeout;
         private final Scheduler.Cancellable timeoutHandler;
@@ -257,6 +267,7 @@
         private final AtomicBoolean notified = new AtomicBoolean();
 
         InferenceAction(
+            String modelId,
            long requestId,
            TimeValue timeout,
            ProcessContext processContext,
@@ -265,6 +276,7 @@
             ThreadPool threadPool,
             ActionListener<InferenceResults> listener
         ) {
+            this.modelId = modelId;
             this.requestId = requestId;
             this.timeout = timeout;
             this.processContext = processContext;
@@ -321,6 +333,9 @@
                 assert config instanceof NlpConfig;
                 NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr);
                 logger.trace(() -> "Inference Request " + request.processInput.utf8ToString());
+                if (request.tokenization.anyTruncated()) {
+                    logger.debug("[{}] [{}] input truncated", modelId, requestId);
+                }
                 PyTorchResultProcessor.PendingResult pendingResult = processContext.getResultProcessor().registerRequest(requestIdStr);
                 processContext.process.get().writeInferenceRequest(request.processInput);
                 waitForResult(

+ 2 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -24,7 +24,6 @@ import java.util.List;
 import java.util.Optional;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD;
 
 public class FillMaskProcessor implements NlpTask.Processor {
 
@@ -100,16 +99,15 @@ public class FillMaskProcessor implements NlpTask.Processor {
             }
         }
         return new FillMaskResults(
-            scoreAndIndices[0].index,
             tokenization.getFromVocab(scoreAndIndices[0].index),
             tokenization.getTokenizations()
                 .get(0)
                 .getInput()
                 .replace(BertTokenizer.MASK_TOKEN, tokenization.getFromVocab(scoreAndIndices[0].index)),
             results,
-            DEFAULT_TOP_CLASSES_RESULTS_FIELD,
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
-            scoreAndIndices[0].score
+            scoreAndIndices[0].score,
+            tokenization.anyTruncated()
         );
     }
 }

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -213,7 +213,12 @@ public class NerProcessor implements NlpTask.Processor {
                     ? tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT)
                     : tokenization.getTokenizations().get(0).getInput()
             );
-            return new NerResults(resultsField, buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities), entities);
+            return new NerResults(
+                resultsField,
+                buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities),
+                entities,
+                tokenization.anyTruncated()
+            );
         }
 
         /**

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -53,7 +53,8 @@ public class PassThroughProcessor implements NlpTask.Processor {
         // TODO - process all results in the batch
         return new PyTorchPassThroughResults(
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
-            pyTorchResult.getInferenceResult()[0]
+            pyTorchResult.getInferenceResult()[0],
+            tokenization.anyTruncated()
         );
     }
 }

+ 3 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -7,12 +7,11 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
-import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -26,7 +25,6 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD;
 
 public class TextClassificationProcessor implements NlpTask.Processor {
 
@@ -109,20 +107,15 @@ public class TextClassificationProcessor implements NlpTask.Processor {
             .mapToInt(i -> i)
             .toArray();
 
-        return new ClassificationInferenceResults(
-            sortedIndices[0],
+        return new NlpClassificationInferenceResults(
             labels.get(sortedIndices[0]),
             Arrays.stream(sortedIndices)
                 .mapToObj(i -> new TopClassEntry(labels.get(i), normalizedScores[i]))
                 .limit(numTopClasses)
                 .collect(Collectors.toList()),
-            List.of(),
-            DEFAULT_TOP_CLASSES_RESULTS_FIELD,
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
-            PredictionFieldType.STRING,
-            0,
             normalizedScores[sortedIndices[0]],
-            null
+            tokenization.anyTruncated()
         );
     }
 }

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -50,7 +50,8 @@ public class TextEmbeddingProcessor implements NlpTask.Processor {
         // TODO - process all results in the batch
         return new TextEmbeddingResults(
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
-            pyTorchResult.getInferenceResult()[0][0]
+            pyTorchResult.getInferenceResult()[0][0],
+            tokenization.anyTruncated()
         );
     }
 }

+ 3 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -8,12 +8,11 @@
 package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.common.logging.LoggerMessageFormat;
-import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -31,7 +30,6 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD;
 
 public class ZeroShotClassificationProcessor implements NlpTask.Processor {
 
@@ -198,17 +196,12 @@
                 .mapToInt(i -> i)
                 .toArray();
 
-            return new ClassificationInferenceResults(
-                sortedIndices[0],
+            return new NlpClassificationInferenceResults(
                 labels[sortedIndices[0]],
                 Arrays.stream(sortedIndices).mapToObj(i -> new TopClassEntry(labels[i], normalizedScores[i])).collect(Collectors.toList()),
-                List.of(),
-                DEFAULT_TOP_CLASSES_RESULTS_FIELD,
                 Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
-                PredictionFieldType.STRING,
-                0,
                 normalizedScores[sortedIndices[0]],
-                null
+                tokenization.anyTruncated()
             );
         }
     }

+ 7 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java

@@ -118,10 +118,12 @@ public class BertTokenizer implements NlpTokenizer {
         List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
         List<Integer> tokenPositionMap = innerResult.v2();
         int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size();
+        boolean isTruncated = false;
         if (numTokens > maxSequenceLength) {
             switch (truncate) {
                 case FIRST:
                 case SECOND:
+                    isTruncated = true;
                     wordPieceTokens = wordPieceTokens.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength);
                     break;
                 case NONE:
@@ -158,7 +160,7 @@
             tokenMap[i] = SPECIAL_TOKEN_POSITION;
         }
 
-        return new TokenizationResult.Tokenization(seq, tokens, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(seq, isTruncated, tokens, tokenIds, tokenMap);
     }
 
     @Override
@@ -175,9 +177,11 @@
         // [CLS] seq1 [SEP] seq2 [SEP]
         int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3;
 
+        boolean isTruncated = false;
         if (numTokens > maxSequenceLength) {
             switch (truncate) {
                 case FIRST:
+                    isTruncated = true;
                     if (wordPieceTokenSeq2s.size() > maxSequenceLength - 3) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the second sequence. "
@@ -191,6 +195,7 @@
                     wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq2s.size());
                     break;
                 case SECOND:
+                    isTruncated = true;
                     if (wordPieceTokenSeq1s.size() > maxSequenceLength - 3) {
                         throw ExceptionsHelper.badRequestException(
                             "Attempting truncation [{}] but input is too large for the first sequence. "
@@ -245,15 +250,7 @@
         tokenIds[i] = vocab.get(SEPARATOR_TOKEN);
         tokenMap[i] = SPECIAL_TOKEN_POSITION;
 
-        // TODO handle seq1 truncation
-        if (tokenIds.length > maxSequenceLength) {
-            throw ExceptionsHelper.badRequestException(
-                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
-                tokenIds.length,
-                maxSequenceLength
-            );
-        }
-        return new TokenizationResult.Tokenization(seq1 + seq2, tokens, tokenIds, tokenMap);
+        return new TokenizationResult.Tokenization(seq1 + seq2, isTruncated, tokens, tokenIds, tokenMap);
     }
 
     private Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerTokenize(String seq) {
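
Both truncation branches above work against the same length budget: a single sequence reserves two special tokens ([CLS] and [SEP]) when withSpecialTokens is set, and a sequence pair reserves three ([CLS], [SEP], [SEP]), with the truncated sequence cut down to whatever remains. A minimal standalone sketch of that arithmetic (plain Java, hypothetical lengths; the NONE branch is not modeled here because its behavior is outside this hunk):

    class TruncationBudgetSketch {
        public static void main(String[] args) {
            int maxSequenceLength = 512;     // hypothetical model limit
            boolean withSpecialTokens = true;

            // Single sequence: [CLS] seq [SEP] -> two special tokens reserved.
            int singleSeqBudget = withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength;

            // Sequence pair: [CLS] seq1 [SEP] seq2 [SEP] -> three special tokens reserved;
            // with truncate=FIRST the first sequence keeps what the second leaves over.
            int seq2Tokens = 40;             // hypothetical second-sequence length
            int firstSeqBudget = maxSequenceLength - 3 - seq2Tokens;

            System.out.println(singleSeqBudget + " " + firstSeqBudget);
        }
    }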

+ 13 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -21,6 +21,10 @@ public class TokenizationResult {
         this.maxLength = -1;
     }
 
+    public boolean anyTruncated() {
+        return tokenizations.stream().anyMatch(Tokenization::isTruncated);
+    }
+
     public String getFromVocab(int tokenId) {
         return vocab.get(tokenId);
     }
@@ -29,9 +33,9 @@
         return tokenizations;
     }
 
-    public void addTokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
+    public void addTokenization(String input, boolean isTruncated, String[] tokens, int[] tokenIds, int[] tokenMap) {
         maxLength = Math.max(maxLength, tokenIds.length);
-        tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap));
+        tokenizations.add(new Tokenization(input, isTruncated, tokens, tokenIds, tokenMap));
     }
 
     public void addTokenization(Tokenization tokenization) {
@@ -49,14 +53,16 @@
         private final String[] tokens;
         private final int[] tokenIds;
         private final int[] tokenMap;
+        private final boolean truncated;
 
-        public Tokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) {
+        public Tokenization(String input, boolean truncated, String[] tokens, int[] tokenIds, int[] tokenMap) {
             assert tokens.length == tokenIds.length;
             assert tokenIds.length == tokenMap.length;
             this.inputSeqs = input;
             this.tokens = tokens;
             this.tokenIds = tokenIds;
             this.tokenMap = tokenMap;
+            this.truncated = truncated;
         }
 
         /**
@@ -91,5 +97,9 @@
         public String getInput() {
             return inputSeqs;
         }
+
+        public boolean isTruncated() {
+            return truncated;
+        }
     }
 }
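
anyTruncated() simply aggregates the new per-tokenization flag across the batch. A stripped-down, self-contained sketch of that aggregation (hypothetical stand-in class, not the Elasticsearch types):

    import java.util.List;

    class AnyTruncatedSketch {
        // Minimal stand-in for TokenizationResult.Tokenization, keeping only the flag.
        static class Tokenization {
            private final boolean truncated;
            Tokenization(boolean truncated) { this.truncated = truncated; }
            boolean isTruncated() { return truncated; }
        }

        public static void main(String[] args) {
            List<Tokenization> tokenizations = List.of(new Tokenization(false), new Tokenization(true));

            // Mirrors anyTruncated(): true if any input in the batch was cut off.
            boolean anyTruncated = tokenizations.stream().anyMatch(Tokenization::isTruncated);
            System.out.println(anyTruncated); // prints: true
        }
    }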

+ 3 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -53,6 +53,7 @@ public class DeploymentManagerTests extends ESTestCase {
 
 
         ListenerCounter listener = new ListenerCounter();
         DeploymentManager.InferenceAction action = new DeploymentManager.InferenceAction(
+            "test-model",
             1,
             TimeValue.MAX_VALUE,
             processContext,
@@ -72,6 +73,7 @@
         assertThat(listener.responseCounts, equalTo(1));
 
         action = new DeploymentManager.InferenceAction(
+            "test-model",
             1,
             TimeValue.MAX_VALUE,
             processContext,
@@ -91,6 +93,7 @@
         assertThat(listener.responseCounts, equalTo(1));
 
         action = new DeploymentManager.InferenceAction(
+            "test-model",
             1,
             TimeValue.MAX_VALUE,
             processContext,

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -50,7 +50,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         int[] tokenIds = new int[] { 0, 1, 2, 3, 4, 5 };
 
         TokenizationResult tokenization = new TokenizationResult(vocab);
-        tokenization.addTokenization(input, tokens, tokenIds, tokenMap);
+        tokenization.addTokenization(input, false, tokens, tokenIds, tokenMap);
 
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
@@ -73,7 +73,7 @@
 
     public void testProcessResults_GivenMissingTokens() {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
-        tokenization.addTokenization("", new String[] {}, new int[] {}, new int[] {});
+        tokenization.addTokenization("", false, new String[] {}, new int[] {}, new int[] {});
 
         PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null);
         assertThat(