|
@@ -87,48 +87,78 @@ public class QuestionAnsweringProcessor extends NlpTask.Processor {
|
|
|
if (pyTorchResult.getInferenceResult().length < 1) {
|
|
|
throw new ElasticsearchStatusException("question answering result has no data", RestStatus.INTERNAL_SERVER_ERROR);
|
|
|
}
|
|
|
+
|
|
|
+ // The result format is pairs of 'start' and 'end' logits,
|
|
|
+ // one pair for each span.
|
|
|
+ // Multiple spans occur where the context text is longer than
|
|
|
+ // the max sequence length, so the input must be windowed with
|
|
|
+ // overlap and evaluated in multiple calls.
|
|
|
+ // Note the response format changed in 8.9 due to the change in
|
|
|
+ // pytorch_inference to not process requests in batches.
|
|
|
+
|
|
|
+ // The output tensor is a 3d array of doubles.
|
|
|
+ // 1. The 1st index is the pairs of start and end for each span.
|
|
|
+ // If there is 1 span there will be 2 elements in this dimension,
|
|
|
+ // for 2 spans 4 elements
|
|
|
+ // 2. The 2nd index is the number results per span.
|
|
|
+ // This dimension is always equal to 1.
|
|
|
+ // 3. The 3rd index is the actual scores.
|
|
|
+ // This is an array of doubles equal in size to the number of
|
|
|
+ // input tokens plus and delimiters (e.g. SEP and CLS tokens)
|
|
|
+ // added by the tokenizer.
|
|
|
+ //
|
|
|
+ // inferenceResult[span_index_start_end][0][scores]
|
|
|
+
|
|
|
// Should be a collection of "starts" and "ends"
|
|
|
- if (pyTorchResult.getInferenceResult().length != 2) {
|
|
|
+ if (pyTorchResult.getInferenceResult().length % 2 != 0) {
|
|
|
throw new ElasticsearchStatusException(
|
|
|
- "question answering result has invalid dimension, expected 2 found [{}]",
|
|
|
+ "question answering result has invalid dimension, number of dimensions must be a multiple of 2 found [{}]",
|
|
|
RestStatus.INTERNAL_SERVER_ERROR,
|
|
|
pyTorchResult.getInferenceResult().length
|
|
|
);
|
|
|
}
|
|
|
- double[][] starts = pyTorchResult.getInferenceResult()[0];
|
|
|
- double[][] ends = pyTorchResult.getInferenceResult()[1];
|
|
|
- if (starts.length != ends.length) {
|
|
|
- throw new ElasticsearchStatusException(
|
|
|
- "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
|
|
|
- RestStatus.INTERNAL_SERVER_ERROR,
|
|
|
- starts.length,
|
|
|
- ends.length
|
|
|
- );
|
|
|
- }
|
|
|
+
|
|
|
+ final int numAnswersToGather = Math.max(numTopClasses, 1);
|
|
|
+ ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
|
|
|
List<TokenizationResult.Tokens> tokensList = tokenization.getTokensBySequenceId().get(0);
|
|
|
- if (starts.length != tokensList.size()) {
|
|
|
+
|
|
|
+ int numberOfSpans = pyTorchResult.getInferenceResult().length / 2;
|
|
|
+ if (numberOfSpans != tokensList.size()) {
|
|
|
throw new ElasticsearchStatusException(
|
|
|
- "question answering result has invalid dimensions; start positions number [{}] equal batched token size [{}]",
|
|
|
+ "question answering result has invalid dimensions; the number of spans [{}] does not match batched token size [{}]",
|
|
|
RestStatus.INTERNAL_SERVER_ERROR,
|
|
|
- starts.length,
|
|
|
+ numberOfSpans,
|
|
|
tokensList.size()
|
|
|
);
|
|
|
}
|
|
|
- final int numAnswersToGather = Math.max(numTopClasses, 1);
|
|
|
|
|
|
- ScoreAndIndicesPriorityQueue finalEntries = new ScoreAndIndicesPriorityQueue(numAnswersToGather);
|
|
|
- for (int i = 0; i < starts.length; i++) {
|
|
|
+ for (int spanIndex = 0; spanIndex < numberOfSpans; spanIndex++) {
|
|
|
+ double[][] starts = pyTorchResult.getInferenceResult()[spanIndex * 2];
|
|
|
+ double[][] ends = pyTorchResult.getInferenceResult()[(spanIndex * 2) + 1];
|
|
|
+ assert starts.length == 1;
|
|
|
+ assert ends.length == 1;
|
|
|
+
|
|
|
+ if (starts.length != ends.length) {
|
|
|
+ throw new ElasticsearchStatusException(
|
|
|
+ "question answering result has invalid dimensions; start positions [{}] must equal potential end [{}]",
|
|
|
+ RestStatus.INTERNAL_SERVER_ERROR,
|
|
|
+ starts.length,
|
|
|
+ ends.length
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
topScores(
|
|
|
- starts[i],
|
|
|
- ends[i],
|
|
|
+ starts[0], // always 1 element in this dimension
|
|
|
+ ends[0],
|
|
|
numAnswersToGather,
|
|
|
finalEntries::insertWithOverflow,
|
|
|
- tokensList.get(i).seqPairOffset(),
|
|
|
- tokensList.get(i).tokenIds().length,
|
|
|
+ tokensList.get(spanIndex).seqPairOffset(),
|
|
|
+ tokensList.get(spanIndex).tokenIds().length,
|
|
|
maxAnswerLength,
|
|
|
- i
|
|
|
+ spanIndex
|
|
|
);
|
|
|
}
|
|
|
+
|
|
|
QuestionAnsweringInferenceResults.TopAnswerEntry[] topAnswerList =
|
|
|
new QuestionAnsweringInferenceResults.TopAnswerEntry[numAnswersToGather];
|
|
|
for (int i = numAnswersToGather - 1; i >= 0; i--) {
|