1
0
Эх сурвалжийг харах

[ML] PyTorch Sequence Classification (Sentiment Analysis) task (#73764)

Adds the Sequence Classification (Sentiment Analysis) task which given
some input text returns a positive_score and negative_score. These 
values are softmax normalised.
David Kyle 4 жил өмнө
parent
commit
41719d64c8

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -130,7 +130,9 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
             results.toXContent(builder, params);
+            builder.endObject();
             return builder;
         }
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java

@@ -42,7 +42,7 @@ public class FillMaskResults implements InferenceResults {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startArray();
+        builder.startArray(DEFAULT_RESULTS_FIELD);
         for (Prediction prediction : predictions) {
             prediction.toXContent(builder, params);
         }

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java

@@ -37,7 +37,7 @@ public class NerResults implements InferenceResults {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startArray();
+        builder.startArray("entities");
         for (EntityGroup entity : entityGroups) {
             entity.toXContent(builder, params);
         }

+ 0 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java

@@ -40,9 +40,7 @@ public class PyTorchPassThroughResults implements InferenceResults {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
         builder.field(INFERENCE.getPreferredName(), inference);
-        builder.endObject();
         return builder;
     }
 

+ 91 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SentimentAnalysisResults.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+public class SentimentAnalysisResults implements InferenceResults {
+
+    public static final String NAME = "sentiment_analysis_result";
+
+    static final String POSITIVE_SCORE = "positive_score";
+    static final String NEGATIVE_SCORE = "negative_score";
+
+    private final double positiveScore;
+    private final double negativeScore;
+
+    public SentimentAnalysisResults(double positiveScore, double negativeScore) {
+        this.positiveScore = positiveScore;
+        this.negativeScore = negativeScore;
+    }
+
+    public SentimentAnalysisResults(StreamInput in) throws IOException {
+        positiveScore = in.readDouble();
+        negativeScore = in.readDouble();
+    }
+
+    public double getPositiveScore() {
+        return positiveScore;
+    }
+
+    public double getNegativeScore() {
+        return negativeScore;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(POSITIVE_SCORE, positiveScore);
+        builder.field(NEGATIVE_SCORE, negativeScore);
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeDouble(positiveScore);
+        out.writeDouble(negativeScore);
+    }
+
+    @Override
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(POSITIVE_SCORE, positiveScore);
+        map.put(NEGATIVE_SCORE, negativeScore);
+        return map;
+    }
+
+    @Override
+    public Object predictedValue() {
+        return positiveScore;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SentimentAnalysisResults that = (SentimentAnalysisResults) o;
+        return Double.compare(that.positiveScore, positiveScore) == 0 &&
+            Double.compare(that.negativeScore, negativeScore) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(positiveScore, negativeScore);
+    }
+}

+ 38 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/SentimentAnalysisResultsTests.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class SentimentAnalysisResultsTests extends AbstractWireSerializingTestCase<SentimentAnalysisResults> {
+    @Override
+    protected Writeable.Reader<SentimentAnalysisResults> instanceReader() {
+        return SentimentAnalysisResults::new;
+    }
+
+    @Override
+    protected SentimentAnalysisResults createTestInstance() {
+        return new SentimentAnalysisResults(randomDouble(), randomDouble());
+    }
+
+    public void testAsMap() {
+        SentimentAnalysisResults testInstance = createTestInstance();
+        Map<String, Object> asMap = testInstance.asMap();
+        assertThat(asMap.keySet(), hasSize(2));
+        assertThat(testInstance.getPositiveScore(),
+            closeTo((Double)asMap.get(SentimentAnalysisResults.POSITIVE_SCORE), 0.0001));
+        assertThat(testInstance.getNegativeScore(),
+            closeTo((Double)asMap.get(SentimentAnalysisResults.NEGATIVE_SCORE), 0.0001));
+    }
+}

+ 76 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessor.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+public class SentimentAnalysisProcessor implements NlpTask.Processor {
+
+    private final BertTokenizer tokenizer;
+
+    SentimentAnalysisProcessor(BertTokenizer tokenizer) {
+        this.tokenizer = tokenizer;
+    }
+    @Override
+    public void validateInputs(String inputs) {
+        // nothing to validate
+    }
+
+    @Override
+    public NlpTask.RequestBuilder getRequestBuilder() {
+        return this::buildRequest;
+    }
+
+    BytesReference buildRequest(String input, String requestId) throws IOException {
+        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize(input);
+        return jsonRequest(tokenization.getTokenIds(), requestId);
+    }
+
+    @Override
+    public NlpTask.ResultProcessor getResultProcessor() {
+        return this::processResult;
+    }
+
+    InferenceResults processResult(PyTorchResult pyTorchResult) {
+        if (pyTorchResult.getInferenceResult().length < 1) {
+            return new WarningInferenceResults("Sentiment analysis result has no data");
+        }
+
+        if (pyTorchResult.getInferenceResult()[0].length < 2) {
+            return new WarningInferenceResults("Expected 2 values in sentiment analysis result");
+        }
+
+        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0]);
+        return new SentimentAnalysisResults(normalizedScores[1], normalizedScores[0]);
+    }
+
+    static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {
+        XContentBuilder builder = XContentFactory.jsonBuilder();
+        builder.startObject();
+        builder.field(BertRequestBuilder.REQUEST_ID, requestId);
+        builder.array(BertRequestBuilder.TOKENS, tokens);
+
+        int[] inputMask = new int[tokens.length];
+        Arrays.fill(inputMask, 1);
+        builder.array(BertRequestBuilder.ARG1, inputMask);
+        builder.endObject();
+
+        // BytesReference.bytes closes the builder
+        return BytesReference.bytes(builder);
+    }
+}

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TaskType.java

@@ -19,6 +19,11 @@ public enum TaskType {
             return new NerProcessor(tokenizer);
         }
     },
+    SENTIMENT_ANALYSIS {
+        public NlpTask.Processor createProcessor(BertTokenizer tokenizer) throws IOException {
+            return new SentimentAnalysisProcessor(tokenizer);
+        }
+    },
     FILL_MASK {
         public NlpTask.Processor createProcessor(BertTokenizer tokenizer) throws IOException {
             return new FillMaskProcessor(tokenizer);

+ 62 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessorTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.nlp;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
+
+public class SentimentAnalysisProcessorTests extends ESTestCase {
+
+    public void testInvalidResult() {
+        SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(mock(BertTokenizer.class));
+        {
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{}, null);
+            InferenceResults inferenceResults = processor.processResult(torchResult);
+            assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
+            assertEquals("Sentiment analysis result has no data",
+                ((WarningInferenceResults) inferenceResults).getWarning());
+        }
+        {
+            PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{{1.0}}, null);
+            InferenceResults inferenceResults = processor.processResult(torchResult);
+            assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
+            assertEquals("Expected 2 values in sentiment analysis result",
+                ((WarningInferenceResults)inferenceResults).getWarning());
+        }
+    }
+
+    public void testBuildRequest() throws IOException {
+        BertTokenizer tokenizer = BertTokenizer.builder(
+            Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
+
+        SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(tokenizer);
+
+        BytesReference bytesReference = processor.buildRequest("Elasticsearch fun", "request1");
+
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
+
+        assertThat(jsonDocAsMap.keySet(), hasSize(3));
+        assertEquals("request1", jsonDocAsMap.get("request_id"));
+        assertEquals(Arrays.asList(3, 0, 1, 2, 4), jsonDocAsMap.get("tokens"));
+        assertEquals(Arrays.asList(1, 1, 1, 1, 1), jsonDocAsMap.get("arg_1"));
+    }
+ }