Browse Source

[ML] Fix concurrent usage of NLP processors (#76719)

This fixes a problem with using NLP processors concurrently.
In particular, the `BertRequestBuilder` holds a local reference
to the tokenization result which is needed for processing the
result. However, this is not thread-safe. Multiple concurrent
calls will override the cached tokenization result and, consequently,
the final results will be wrong.

This commit fixes this by making `BertRequestBuilder` stateless
and keeping the tokenization result on the stack in the method
that's waiting for the result to come through from the process.
Dimitris Athanasiou 4 years ago
parent
commit
3df39a616d

+ 9 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -17,7 +17,6 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -42,6 +41,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
@@ -158,7 +158,6 @@ public class DeploymentManager {
     }
 
     Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
-
         try (InputStream stream = hit.getSourceRef().streamInput();
              XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                  .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
@@ -226,11 +225,12 @@ public class DeploymentManager {
                     String text = NlpTask.extractInput(processContext.modelInput.get(), doc);
                     NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
                     processor.validateInputs(text);
-                    BytesReference request = processor.getRequestBuilder().buildRequest(text, requestId);
-                    logger.trace(() -> "Inference Request "+ request.utf8ToString());
-                    PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.requestWritten(requestId);
-                    processContext.process.get().writeInferenceRequest(request);
-                    waitForResult(processContext, pendingResult, requestId, timeout, processor.getResultProcessor(), listener);
+                    NlpTask.Request request = processor.getRequestBuilder().buildRequest(text, requestId);
+                    logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
+                    PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
+                    processContext.process.get().writeInferenceRequest(request.processInput);
+                    waitForResult(processContext, pendingResult, request.tokenization, requestId, timeout, processor.getResultProcessor(),
+                        listener);
                 } catch (IOException e) {
                     logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.modelId), e);
                     onFailure(ExceptionsHelper.serverError("error writing to process", e));
@@ -245,6 +245,7 @@ public class DeploymentManager {
 
     private void waitForResult(ProcessContext processContext,
                                PyTorchResultProcessor.PendingResult pendingResult,
+                               BertTokenizer.TokenizationResult tokenization,
                                String requestId,
                                TimeValue timeout,
                                NlpTask.ResultProcessor inferenceResultsProcessor,
@@ -269,7 +270,7 @@ public class DeploymentManager {
             }
 
             logger.debug(() -> new ParameterizedMessage("[{}] retrieved result for request [{}]", processContext.modelId, requestId));
-            InferenceResults results = inferenceResultsProcessor.processResult(pyTorchResult);
+            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
             logger.debug(() -> new ParameterizedMessage("[{}] processed result for request [{}]", processContext.modelId, requestId));
             listener.onResponse(results);
         } catch (InterruptedException e) {

+ 3 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilder.java

@@ -25,7 +25,6 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
     static final String ARG3 = "arg_3";
 
     private final BertTokenizer tokenizer;
-    private BertTokenizer.TokenizationResult tokenization;
     private final int maxSequenceLength;
 
     public BertRequestBuilder(BertTokenizer tokenizer, int maxSequenceLength) {
@@ -33,19 +32,15 @@ public class BertRequestBuilder implements NlpTask.RequestBuilder {
         this.maxSequenceLength = maxSequenceLength;
     }
 
-    public BertTokenizer.TokenizationResult getTokenization() {
-        return tokenization;
-    }
-
     @Override
-    public BytesReference buildRequest(String input, String requestId) throws IOException {
-        tokenization = tokenizer.tokenize(input);
+    public NlpTask.Request buildRequest(String input, String requestId) throws IOException {
+        BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize(input);
         if (tokenization.getTokenIds().length > maxSequenceLength) {
             throw ExceptionsHelper.badRequestException(
                 "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                 tokenization.getTokenIds().length, maxSequenceLength);
         }
-        return jsonRequest(tokenization.getTokenIds(), requestId);
+        return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
     }
 
     static BytesReference jsonRequest(int[] tokens, String requestId) throws IOException {

+ 4 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -7,10 +7,10 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
+import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 
 import java.util.ArrayList;
@@ -51,11 +51,10 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor() {
-        return (pyTorchResult) -> processResult(bertRequestBuilder.getTokenization(), pyTorchResult);
+        return (tokenization, pyTorchResult) -> processResult(tokenization, pyTorchResult);
     }
 
-    InferenceResults processResult(BertTokenizer.TokenizationResult tokenization,
-                                           PyTorchResult pyTorchResult) {
+    InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
 
         if (tokenization.getTokens().isEmpty()) {
             return new FillMaskResults(Collections.emptyList());

+ 5 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -23,7 +23,6 @@ import java.util.Collections;
 import java.util.EnumSet;
 import java.util.List;
 import java.util.Locale;
-import java.util.Objects;
 
 public class NerProcessor implements NlpTask.Processor {
 
@@ -130,21 +129,19 @@ public class NerProcessor implements NlpTask.Processor {
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor() {
-        return new NerResultProcessor(bertRequestBuilder.getTokenization(), iobMap);
+        return new NerResultProcessor(iobMap);
     }
 
     static class NerResultProcessor implements NlpTask.ResultProcessor {
 
-        private final BertTokenizer.TokenizationResult tokenization;
         private final IobTag[] iobMap;
 
-        NerResultProcessor(BertTokenizer.TokenizationResult tokenization, IobTag[] iobMap) {
-            this.tokenization = Objects.requireNonNull(tokenization);
+        NerResultProcessor(IobTag[] iobMap) {
             this.iobMap = iobMap;
         }
 
         @Override
-        public InferenceResults processResult(PyTorchResult pyTorchResult) {
+        public InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
             if (tokenization.getTokens().isEmpty()) {
                 return new NerResults(Collections.emptyList());
             }
@@ -155,7 +152,7 @@ public class NerProcessor implements NlpTask.Processor {
             // of maybe (1 + 0) / 2 = 0.5 while before softmax it'd be exp(10 - 5) / normalization
             // which could easily be close to 1.
             double[][] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult());
-            List<TaggedToken> taggedTokens = tagTokens(normalizedScores);
+            List<TaggedToken> taggedTokens = tagTokens(tokenization, normalizedScores);
             List<NerResults.EntityGroup> entities = groupTaggedTokens(taggedTokens);
             return new NerResults(entities);
         }
@@ -166,7 +163,7 @@ public class NerProcessor implements NlpTask.Processor {
          * in the original input replacing them with a single token that
          * gets labelled based on the average score of all its sub-tokens.
          */
-        private List<TaggedToken> tagTokens(double[][] scores) {
+        private List<TaggedToken> tagTokens(BertTokenizer.TokenizationResult tokenization, double[][] scores) {
             List<TaggedToken> taggedTokens = new ArrayList<>();
             int startTokenIndex = 0;
             while (startTokenIndex < tokenization.getTokens().size()) {

+ 13 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 
 import java.io.IOException;
 import java.util.Map;
+import java.util.Objects;
 
 public class NlpTask {
 
@@ -43,11 +44,11 @@ public class NlpTask {
     }
 
     public interface RequestBuilder {
-        BytesReference buildRequest(String inputs, String requestId) throws IOException;
+        Request buildRequest(String inputs, String requestId) throws IOException;
     }
 
     public interface ResultProcessor {
-        InferenceResults processResult(PyTorchResult pyTorchResult);
+        InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult);
     }
 
     public interface Processor {
@@ -75,4 +76,14 @@ public class NlpTask {
         }
         throw ExceptionsHelper.badRequestException("input value [{}] for field [{}] is not a string", inputValue, inputField);
     }
+
+    public static class Request {
+        public final BertTokenizer.TokenizationResult tokenization;
+        public final BytesReference processInput;
+
+        public Request(BertTokenizer.TokenizationResult tokenization, BytesReference processInput) {
+            this.tokenization = Objects.requireNonNull(tokenization);
+            this.processInput = Objects.requireNonNull(processInput);
+        }
+    }
 }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -40,7 +40,7 @@ public class PassThroughProcessor implements NlpTask.Processor {
         return this::processResult;
     }
 
-    private InferenceResults processResult(PyTorchResult pyTorchResult) {
+    private InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         return new PyTorchPassThroughResults(pyTorchResult.getInferenceResult());
     }
 }

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessor.java

@@ -59,9 +59,9 @@ public class SentimentAnalysisProcessor implements NlpTask.Processor {
         return this::buildRequest;
     }
 
-    BytesReference buildRequest(String input, String requestId) throws IOException {
+    NlpTask.Request buildRequest(String input, String requestId) throws IOException {
         BertTokenizer.TokenizationResult tokenization = tokenizer.tokenize(input);
-        return jsonRequest(tokenization.getTokenIds(), requestId);
+        return new NlpTask.Request(tokenization, jsonRequest(tokenization.getTokenIds(), requestId));
     }
 
     @Override
@@ -69,7 +69,7 @@ public class SentimentAnalysisProcessor implements NlpTask.Processor {
         return this::processResult;
     }
 
-    InferenceResults processResult(PyTorchResult pyTorchResult) {
+    InferenceResults processResult(BertTokenizer.TokenizationResult tokenization, PyTorchResult pyTorchResult) {
         if (pyTorchResult.getInferenceResult().length < 1) {
             return new WarningInferenceResults("Sentiment analysis result has no data");
         }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -38,7 +38,7 @@ public class PyTorchResultProcessor {
         this.summaryStatistics = new LongSummaryStatistics();
     }
 
-    public PendingResult requestWritten(String requestId) {
+    public PendingResult registerRequest(String requestId) {
         return pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
     }
 

+ 2 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/BertRequestBuilderTests.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
@@ -28,9 +27,9 @@ public class BertRequestBuilderTests extends ESTestCase {
             Arrays.asList("Elastic", "##search", "fun", BertTokenizer.CLASS_TOKEN, BertTokenizer.SEPARATOR_TOKEN)).build();
 
         BertRequestBuilder requestBuilder = new BertRequestBuilder(tokenizer, 512);
-        BytesReference bytesReference = requestBuilder.buildRequest("Elasticsearch fun", "request1");
+        NlpTask.Request request = requestBuilder.buildRequest("Elasticsearch fun", "request1");
 
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(5));
         assertEquals("request1", jsonDocAsMap.get("request_id"));

+ 13 - 23
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -80,14 +80,16 @@ public class NerProcessorTests extends ESTestCase {
     }
 
     public void testProcessResults_GivenNoTokens() {
-        NerProcessor.NerResultProcessor processor = createProcessor(Collections.emptyList(), "");
-        NerResults result = (NerResults) processor.processResult(new PyTorchResult("test", null, 0L, null));
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values());
+        BertTokenizer.TokenizationResult tokenization = tokenize(Collections.emptyList(), "");
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("test", null, 0L, null));
         assertThat(result.getEntityGroups(), is(empty()));
     }
 
     public void testProcessResults() {
-        NerProcessor.NerResultProcessor processor =
-            createProcessor(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"), "Many use Elasticsearch in London");
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values());
+        BertTokenizer.TokenizationResult tokenization = tokenize(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
+            "Many use Elasticsearch in London");
         double[][] scores = {
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // many
             { 7, 0, 0, 0, 0, 0, 0, 0, 0}, // use
@@ -97,7 +99,7 @@ public class NerProcessorTests extends ESTestCase {
             { 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in
             { 0, 0, 0, 0, 0, 0, 0, 6, 0} // london
         };
-        NerResults result = (NerResults) processor.processResult(new PyTorchResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
 
         assertThat(result.getEntityGroups().size(), equalTo(2));
         assertThat(result.getEntityGroups().get(0).getWord(), equalTo("elasticsearch"));
@@ -120,11 +122,9 @@ public class NerProcessorTests extends ESTestCase {
             NerProcessor.IobTag.O
         };
 
-        NerProcessor.NerResultProcessor processor = createProcessor(
-            Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
-            "Elasticsearch in London",
-            iobMap
-        );
+        NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(iobMap);
+        BertTokenizer.TokenizationResult tokenization = tokenize(Arrays.asList("el", "##astic", "##search", "many", "use", "in", "london"),
+            "Elasticsearch in London");
 
         double[][] scores = {
             { 0.01, 0.01, 0, 0.01, 0, 0, 7, 3, 0}, // el
@@ -133,7 +133,7 @@ public class NerProcessorTests extends ESTestCase {
             { 0, 0, 0, 0, 0, 0, 0, 0, 5}, // in
             { 6, 0, 0, 0, 0, 0, 0, 0, 0} // london
         };
-        NerResults result = (NerResults) processor.processResult(new PyTorchResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
 
         assertThat(result.getEntityGroups().size(), equalTo(2));
         assertThat(result.getEntityGroups().get(0).getWord(), equalTo("elasticsearch"));
@@ -210,21 +210,11 @@ public class NerProcessorTests extends ESTestCase {
         assertThat(entityGroups.get(2).getLabel(), equalTo("organisation"));
     }
 
-    private static NerProcessor.NerResultProcessor createProcessor(List<String> vocab, String input){
+    private static BertTokenizer.TokenizationResult tokenize(List<String> vocab, String input) {
         BertTokenizer tokenizer = BertTokenizer.builder(vocab)
             .setDoLowerCase(true)
             .setWithSpecialTokens(false)
             .build();
-        BertTokenizer.TokenizationResult tokenizationResult = tokenizer.tokenize(input);
-        return new NerProcessor.NerResultProcessor(tokenizationResult, NerProcessor.IobTag.values());
-    }
-
-    private static NerProcessor.NerResultProcessor createProcessor(List<String> vocab, String input, NerProcessor.IobTag[] iobMap){
-        BertTokenizer tokenizer = BertTokenizer.builder(vocab)
-            .setDoLowerCase(true)
-            .setWithSpecialTokens(false)
-            .build();
-        BertTokenizer.TokenizationResult tokenizationResult = tokenizer.tokenize(input);
-        return new NerProcessor.NerResultProcessor(tokenizationResult, iobMap);
+        return tokenizer.tokenize(input);
     }
 }

+ 4 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SentimentAnalysisProcessorTests.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.ml.inference.nlp;
 
 import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
@@ -36,14 +35,14 @@ public class SentimentAnalysisProcessorTests extends ESTestCase {
         SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(mock(BertTokenizer.class), config);
         {
             PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{}, 0L, null);
-            InferenceResults inferenceResults = processor.processResult(torchResult);
+            InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
             assertEquals("Sentiment analysis result has no data",
                 ((WarningInferenceResults) inferenceResults).getWarning());
         }
         {
             PyTorchResult torchResult = new PyTorchResult("foo", new double[][]{{1.0}}, 0L, null);
-            InferenceResults inferenceResults = processor.processResult(torchResult);
+            InferenceResults inferenceResults = processor.processResult(null, torchResult);
             assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
             assertEquals("Expected 2 values in sentiment analysis result",
                 ((WarningInferenceResults)inferenceResults).getWarning());
@@ -57,9 +56,9 @@ public class SentimentAnalysisProcessorTests extends ESTestCase {
         SentimentAnalysisConfig config = new SentimentAnalysisConfig(new VocabularyConfig("test-index", "vocab"), null, null);
         SentimentAnalysisProcessor processor = new SentimentAnalysisProcessor(tokenizer, config);
 
-        BytesReference bytesReference = processor.buildRequest("Elasticsearch fun", "request1");
+        NlpTask.Request request = processor.buildRequest("Elasticsearch fun", "request1");
 
-        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(bytesReference, true, XContentType.JSON).v2();
+        Map<String, Object> jsonDocAsMap = XContentHelper.convertToMap(request.processInput, true, XContentType.JSON).v2();
 
         assertThat(jsonDocAsMap.keySet(), hasSize(3));
         assertEquals("request1", jsonDocAsMap.get("request_id"));