
[ML] Span and merge long docs for the text expansion model (#94224)

Merge duplicate tokens by choosing the highest weighted token
David Kyle 2 years ago
parent commit f8918256c8

+ 0 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfig.java

@@ -77,14 +77,6 @@ public class SlimConfig implements NlpConfig {
                 this.tokenization.getName()
             );
         }
-        // TODO support spanning
-        if (this.tokenization.span != -1) {
-            throw ExceptionsHelper.badRequestException(
-                "[{}] does not support windowing long text sequences; configured span [{}]",
-                NAME,
-                this.tokenization.span
-            );
-        }
         this.resultsField = resultsField;
     }
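
Dropping this guard means a SLIM text expansion config now accepts a tokenization `span`, so long sequences can be windowed instead of rejected up front. As a rough illustration only (not the Elasticsearch tokenizer, and assuming for the sketch that `span` is the number of tokens consecutive windows share and is smaller than the maximum sequence length), windowing a long token sequence might look like:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class WindowingSketch {
    // Split tokenIds into windows of at most maxSeqLen tokens, with consecutive
    // windows overlapping by `span` tokens. Assumed semantics, illustrative only;
    // requires span < maxSeqLen so each window advances.
    static List<int[]> window(int[] tokenIds, int maxSeqLen, int span) {
        List<int[]> windows = new ArrayList<>();
        int stride = maxSeqLen - span; // how far each new window advances
        for (int start = 0;; start += stride) {
            int end = Math.min(start + maxSeqLen, tokenIds.length);
            windows.add(Arrays.copyOfRange(tokenIds, start, end));
            if (end == tokenIds.length) {
                break;
            }
        }
        return windows;
    }
}
```

Each window is then run through the model separately, and the per-window results are merged back into one answer (see the SlimProcessor changes below).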
 

+ 4 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java

@@ -60,7 +60,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
             try {
                 cancellableTask.ensureNotCancelled();
             } catch (TaskCancelledException ex) {
-                logger.debug(() -> format("[%s] %s", getModelId(), ex.getMessage()));
+                logger.warn(() -> format("[%s] %s", getModelId(), ex.getMessage()));
                 return true;
             }
         }
@@ -90,7 +90,8 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
             NlpConfig nlpConfig = (NlpConfig) config;
             NlpTask.Request request = processor.getRequestBuilder(nlpConfig)
                 .buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate(), nlpConfig.getTokenization().getSpan());
-            logger.debug(() -> "Inference Request " + request.processInput().utf8ToString());
+            logger.debug(() -> format("handling request [%s]", requestIdStr));
+            logger.trace(() -> "Inference Request " + request.processInput().utf8ToString());
             if (request.tokenization().anyTruncated()) {
                 logger.debug("[{}] [{}] input truncated", getModelId(), getRequestId());
             }
@@ -140,7 +141,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
             return;
         }
         InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult());
-        logger.debug(() -> format("[%s] processed result for request [%s]", getModelId(), getRequestId()));
+        logger.trace(() -> format("[%s] processed result for request [%s]", getModelId(), getRequestId()));
         onSuccess(results);
     }
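
The logging changes in this file adjust verbosity to match cost: task cancellation is promoted to `warn`, a cheap per-request `debug` line is added, and the full tokenized payload plus the per-result line drop to `trace`. The `() -> format(...)` supplier form matters because the message string is only built when the level is enabled. A minimal log4j2 sketch of the same pattern (illustrative, with hypothetical names; not the Elasticsearch logging wrapper):

```java
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

class LazyLoggingSketch {
    private static final Logger logger = LogManager.getLogger(LazyLoggingSketch.class);

    static void handle(String requestId, String hugePayload) {
        // Cheap, per-request bookkeeping: fine at debug.
        logger.debug(() -> String.format("handling request [%s]", requestId));
        // Potentially huge string: the supplier is never invoked unless
        // trace is enabled, so there is no formatting cost at default levels.
        logger.trace(() -> "Inference Request " + hugePayload);
    }
}
```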
 

+ 29 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessor.java

@@ -43,14 +43,11 @@ public class SlimProcessor extends NlpTask.Processor {
     }
 
     static InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult, String resultsField) {
-        // Convert the verbose results to the sparse format.
-        // Anything with a score > 0.0 is retained.
-        List<SlimResults.WeightedToken> weightedTokens = new ArrayList<>();
-        double[] weights = pyTorchResult.getInferenceResult()[0][0];
-        for (int i = 0; i < weights.length; i++) {
-            if (weights[i] > 0.0) {
-                weightedTokens.add(new SlimResults.WeightedToken(i, (float) weights[i]));
-            }
+        List<SlimResults.WeightedToken> weightedTokens;
+        if (pyTorchResult.getInferenceResult()[0].length == 1) {
+            weightedTokens = sparseVectorToTokenWeights(pyTorchResult.getInferenceResult()[0][0]);
+        } else {
+            weightedTokens = multipleSparseVectorsToTokenWeights(pyTorchResult.getInferenceResult()[0]);
         }
 
         return new SlimResults(
@@ -59,4 +56,28 @@ public class SlimProcessor extends NlpTask.Processor {
             tokenization.anyTruncated()
         );
     }
+
+    static List<SlimResults.WeightedToken> multipleSparseVectorsToTokenWeights(double[][] vector) {
+        // reduce to a single 1d array choosing the max value
+        // in each column and placing that in the first row
+        for (int i = 1; i < vector.length; i++) {
+            for (int tokenId = 0; tokenId < vector[i].length; tokenId++) {
+                if (vector[i][tokenId] > vector[0][tokenId]) {
+                    vector[0][tokenId] = vector[i][tokenId];
+                }
+            }
+        }
+        return sparseVectorToTokenWeights(vector[0]);
+    }
+
+    static List<SlimResults.WeightedToken> sparseVectorToTokenWeights(double[] vector) {
+        // Anything with a score > 0.0 is retained.
+        List<SlimResults.WeightedToken> weightedTokens = new ArrayList<>();
+        for (int i = 0; i < vector.length; i++) {
+            if (vector[i] > 0.0) {
+                weightedTokens.add(new SlimResults.WeightedToken(i, (float) vector[i]));
+            }
+        }
+        return weightedTokens;
+    }
 }
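
`multipleSparseVectorsToTokenWeights` handles the spanned case: each window of a long document produces its own sparse vector, and duplicate tokens across windows are merged by taking the element-wise maximum, per the commit message. A standalone sketch of the same reduction, using a hypothetical `TokenWeight` record in place of `SlimResults.WeightedToken` (and cloning the first row rather than mutating it in place, unlike the production code):

```java
import java.util.ArrayList;
import java.util.List;

class MergeSketch {
    record TokenWeight(int tokenId, float weight) {}

    // Element-wise max across window rows, then keep strictly positive entries.
    static List<TokenWeight> merge(double[][] windows) {
        double[] merged = windows[0].clone();
        for (int row = 1; row < windows.length; row++) {
            for (int tokenId = 0; tokenId < windows[row].length; tokenId++) {
                merged[tokenId] = Math.max(merged[tokenId], windows[row][tokenId]);
            }
        }
        List<TokenWeight> tokens = new ArrayList<>();
        for (int i = 0; i < merged.length; i++) {
            if (merged[i] > 0.0) {
                tokens.add(new TokenWeight(i, (float) merged[i]));
            }
        }
        return tokens;
    }

    public static void main(String[] args) {
        double[][] windows = { { 0.0, 1.0, 0.0, 1.0, 4.0, 0.0, 0.0 }, { 1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 0.1 } };
        // Prints tokens 0, 1, 3, 4 and 6, each with the max weight from either window.
        System.out.println(merge(windows));
    }
}
```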

+ 19 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/SlimProcessorTests.java

@@ -36,4 +36,23 @@ public class SlimProcessorTests extends ESTestCase {
         assertEquals(new SlimResults.WeightedToken(3, 3.0f), weightedTokens.get(1));
         assertEquals(new SlimResults.WeightedToken(4, 4.0f), weightedTokens.get(2));
     }
+
+    public void testProcessResultMultipleVectors() {
+        double[][][] pytorchResult = new double[][][] { { { 0.0, 1.0, 0.0, 1.0, 4.0, 0.0, 0.0 }, { 1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 0.1 } } };
+
+        TokenizationResult tokenizationResult = new BertTokenizationResult(List.of(), List.of(), 0);
+
+        var inferenceResult = SlimProcessor.processResult(tokenizationResult, new PyTorchInferenceResult(pytorchResult), "foo");
+        assertThat(inferenceResult, instanceOf(SlimResults.class));
+        var slimResults = (SlimResults) inferenceResult;
+        assertEquals(slimResults.getResultsField(), "foo");
+
+        var weightedTokens = slimResults.getWeightedTokens();
+        assertThat(weightedTokens, hasSize(5));
+        assertEquals(new SlimResults.WeightedToken(0, 1.0f), weightedTokens.get(0));
+        assertEquals(new SlimResults.WeightedToken(1, 2.0f), weightedTokens.get(1));
+        assertEquals(new SlimResults.WeightedToken(3, 3.0f), weightedTokens.get(2));
+        assertEquals(new SlimResults.WeightedToken(4, 4.0f), weightedTokens.get(3));
+        assertEquals(new SlimResults.WeightedToken(6, 0.1f), weightedTokens.get(4));
+    }
 }
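
The expected tokens in this test follow directly from the element-wise maximum of the two window rows, [1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 0.1]; the zero entries at token ids 2 and 5 are dropped, leaving the five weighted tokens asserted above.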

+ 19 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorServiceTests.java

@@ -13,6 +13,8 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
 import org.junit.After;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService.RequestPriority;
@@ -102,20 +104,29 @@ public class PriorityProcessWorkerExecutorServiceTests extends ESTestCase {
 
         var counter = new AtomicInteger();
         long requestId = 1;
-        var r1 = new RunOrderValidator(2, counter);
-        executor.executeWithPriority(r1, RequestPriority.NORMAL, requestId++);
-        executor.executeWithPriority(new RunOrderValidator(3, counter), RequestPriority.NORMAL, requestId++);
-        executor.executeWithPriority(new RunOrderValidator(4, counter), RequestPriority.NORMAL, requestId++);
-        executor.executeWithPriority(new RunOrderValidator(1, counter), RequestPriority.HIGH, requestId++);
-        executor.executeWithPriority(new RunOrderValidator(5, counter), RequestPriority.NORMAL, requestId++);
-        executor.executeWithPriority(new RunOrderValidator(6, counter), RequestPriority.NORMAL, requestId++);
+        List<RunOrderValidator> validators = new ArrayList<>();
+        validators.add(new RunOrderValidator(2, counter));
+        validators.add(new RunOrderValidator(3, counter));
+        validators.add(new RunOrderValidator(4, counter));
+        validators.add(new RunOrderValidator(1, counter));   // high priority request runs first
+        validators.add(new RunOrderValidator(5, counter));
+        validators.add(new RunOrderValidator(6, counter));
+
+        executor.executeWithPriority(validators.get(0), RequestPriority.NORMAL, requestId++);
+        executor.executeWithPriority(validators.get(1), RequestPriority.NORMAL, requestId++);
+        executor.executeWithPriority(validators.get(2), RequestPriority.NORMAL, requestId++);
+        executor.executeWithPriority(validators.get(3), RequestPriority.HIGH, requestId++);
+        executor.executeWithPriority(validators.get(4), RequestPriority.NORMAL, requestId++);
+        executor.executeWithPriority(validators.get(5), RequestPriority.NORMAL, requestId++);
 
         // final action stops the executor
         executor.executeWithPriority(new ShutdownExecutorRunnable(executor), RequestPriority.NORMAL, 10000L);
 
         executor.start();
 
-        assertTrue(r1.hasBeenRun);
+        for (var validator : validators) {
+            assertTrue(validator.hasBeenRun);
+        }
     }
 
     private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(int queueSize) {
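
The test refactor keeps a handle on every validator so the final loop can assert that all of them ran, not just the first. The ordering it exercises, HIGH before NORMAL and insertion order within a priority, is what a comparator over (priority, request id) gives you. A minimal sketch of that ordering (not the Elasticsearch executor, which this test covers):

```java
import java.util.concurrent.PriorityBlockingQueue;

class PrioritySketch {
    enum Priority { HIGH, NORMAL }

    record Task(Priority priority, long requestId, Runnable runnable) implements Comparable<Task> {
        public int compareTo(Task other) {
            // HIGH sorts before NORMAL; ties break on insertion (request id) order.
            int byPriority = priority.compareTo(other.priority);
            return byPriority != 0 ? byPriority : Long.compare(requestId, other.requestId);
        }
    }

    public static void main(String[] args) {
        var queue = new PriorityBlockingQueue<Task>();
        queue.add(new Task(Priority.NORMAL, 1, () -> System.out.println("normal 1")));
        queue.add(new Task(Priority.NORMAL, 2, () -> System.out.println("normal 2")));
        queue.add(new Task(Priority.HIGH, 3, () -> System.out.println("high 3")));
        // Drains as: high 3, normal 1, normal 2.
        while (!queue.isEmpty()) {
            queue.poll().runnable().run();
        }
    }
}
```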