Browse Source

[ML] Adapt Question Answering processing for non-batched evaluation (#98167)

David Kyle 2 years ago
parent
commit
51f1a989cc

+ 6 - 0
docs/changelog/98167.yaml

@@ -0,0 +1,6 @@
+pr: 98167
+summary: Fix failure processing Question Answering model output where the input has been spanned over multiple sequences   
+area: Machine Learning
+type: bug
+issues:
+ - 97917

+ 53 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java

@@ -87,48 +87,78 @@ public class QuestionAnsweringProcessor extends NlpTask.Processor {
             if (pyTorchResult.getInferenceResult().length < 1) {
                 throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
             }
+
+            // The result format is pairs of 'start' and 'end' logits,
+            // one pair for each span.
+            // Multiple spans occur where the context text is longer than
+            // the max sequence length, so the input must be windowed with
+            // overlap and evaluated in multiple calls.
+            // Note the response format changed in 8.9 due to the change in
+            // pytorch_inference to not process requests in batches.
+
+            // The output tensor is a 3d array of doubles.
+            // 1. The 1st index is the pairs of start and end for each span.
+            // If there is 1 span there will be 2 elements in this dimension,
+            // for 2 spans 4 elements
+            // 2. The 2nd index is the number results per span.
+            // This dimension is always equal to 1.
+            // 3. The 3rd index is the actual scores.
+            // This is an array of doubles equal in size to the number of
+            // input tokens plus and delimiters (e.g. SEP and CLS tokens)
+            // added by the tokenizer.
+            //
+            // inferenceResult[span_index_start_end][0][scores]
+
             // Should be a collection of "starts" and "ends"
-            if (pyTorchResult.getInferenceResult().length != 2) {
+            if (pyTorchResult.getInferenceResult().length % 2 != 0) {
                 throw new ElasticsearchStatusException(
-                    "question answering result has invalid dimension, expected 2 found [{}]",
+                    "question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]",
                     RestStatus.INTERNAL_SERVER_ERROR,
                     pyTorchResult.getInferenceResult().length
                 );
             }
-            double[][] starts = pyTorchResult.getInferenceResult()[0];
-            double[][] ends = pyTorchResult.getInferenceResult()[1];
-            if (starts.length != ends.length) {
-                throw new ElasticsearchStatusException(
-                    "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
-                    RestStatus.INTERNAL_SERVER_ERROR,
-                    starts.length,
-                    ends.length
-                );
-            }
+
+            final int numAnswersToGather = Math.max(numTopClasses, 1);
+            ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
             List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
-            if (starts.length != tokensList.size()) {
+
+            int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
+            if (numberOfSpans != tokensList.size()) {
                 throw new ElasticsearchStatusException(
-                    "question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
+                    "question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
                     RestStatus.INTERNAL_SERVER_ERROR,
-                    starts.length,
+                    numberOfSpans,
                     tokensList.size()
                 );
             }
-            final int numAnswersToGather = Math.max(numTopClasses, 1);
 
-            ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
-            for (int i = 0; i < starts.length; i++) {
+            for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
+                double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
+                double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
+                assert starts.length == 1;
+                assert ends.length == 1;
+
+                if (starts.length != ends.length) {
+                    throw new ElasticsearchStatusException(
+                        "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
+                        RestStatus.INTERNAL_SERVER_ERROR,
+                        starts.length,
+                        ends.length
+                    );
+                }
+
                 topScores(
-                    starts[i],
-                    ends[i],
+                    starts[0], // always 1 element in this dimension
+                    ends[0],
                     numAnswersToGather,
                     finalEntries::insertWithOverflow,
-                    tokensList.get(i).seqPairOffset(),
-                    tokensList.get(i).tokenIds().length,
+                    tokensList.get(spanIndex).seqPairOffset(),
+                    tokensList.get(spanIndex).tokenIds().length,
                     maxAnswerLength,
-                    i
+                    spanIndex
                 );
             }
+
             QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList =
                 new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
             for (int i = numAnswersToGather - 1; i >= 0; i--) {

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java

@@ -48,7 +48,7 @@ public abstract class TokenizationResult {
         return tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
     }
 
-    List<Tokens> getTokens() {
+    public List<Tokens> getTokens() {
         return tokens;
     }
 

+ 65 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResu
 import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.DoubleStream;
 
 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
 import static org.hamcrest.Matchers.closeTo;
@@ -168,4 +169,68 @@ public class QuestionAnsweringProcessorTests extends ESTestCase {
         assertThat(topScores[1].endToken(), equalTo(5));
     }
 
+    public void testProcessorMuliptleSpans() throws IOException {
+        String question = "is Elasticsearch fun?";
+        String input = "Pancake day is fun with Elasticsearch and little red car";
+        int span = 4;
+        int maxSequenceLength = 14;
+        int numberTopClasses = 3;
+
+        BertTokenization tokenization = new BertTokenization(false, true, maxSequenceLength, Tokenization.Truncate.NONE, span);
+        BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
+        QuestionAnsweringConfig config = new QuestionAnsweringConfig(
+            question,
+            numberTopClasses,
+            10,
+            new VocabularyConfig("index_name"),
+            tokenization,
+            "prediction"
+        );
+        QuestionAnsweringProcessor processor = new QuestionAnsweringProcessor(tokenizer);
+        TokenizationResult tokenizationResult = processor.getRequestBuilder(config)
+            .buildRequest(List.of(input), "1", Tokenization.Truncate.NONE, span)
+            .tokenization();
+        assertThat(tokenizationResult.anyTruncated(), is(false));
+
+        // now we know what the tokenization looks like
+        // (number of spans and size of each) fake the
+        // question answering response
+
+        int numberSpans = tokenizationResult.getTokens().size();
+        double[][][] modelTensorOutput = new double[numberSpans * 2][][];
+        for (int i = 0; i < numberSpans; i++) {
+            var windowTokens = tokenizationResult.getTokens().get(i);
+            // size of output
+            int outputSize = windowTokens.tokenIds().length;
+            // generate low value -ve scores that will not mark
+            // the expected result with a high degree of probability
+            double[] starts = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            double[] ends = DoubleStream.generate(() -> -randomDoubleBetween(0.001, 1.0, true)).limit(outputSize).toArray();
+            modelTensorOutput[i * 2] = new double[][] { starts };
+            modelTensorOutput[(i * 2) + 1] = new double[][] { ends };
+        }
+
+        int spanContainingTheAnswer = randomIntBetween(0, numberSpans - 1);
+
+        // insert numbers to mark the answer in the chosen span
+        int answerStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).seqPairOffset(); // first token of second sequence
+        // last token of the second sequence ignoring the final SEP added by the BERT tokenizer
+        int answerEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokenIds().length - 2;
+        modelTensorOutput[spanContainingTheAnswer * 2][0][answerStart] = 0.5;
+        modelTensorOutput[(spanContainingTheAnswer * 2) + 1][0][answerEnd] = 1.0;
+
+        NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(modelTensorOutput);
+        QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
+            tokenizationResult,
+            pyTorchResult
+        );
+
+        // The expected answer is the full text of the span containing the answer
+        int expectedStart = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(0).startOffset();
+        int lastTokenPosition = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).size() - 1;
+        int expectedEnd = tokenizationResult.getTokens().get(spanContainingTheAnswer).tokens().get(1).get(lastTokenPosition).endOffset();
+
+        assertThat(result.getAnswer(), equalTo(input.substring(expectedStart, expectedEnd)));
+    }
 }