1
0
Эх сурвалжийг харах

[ML] Fix out of bounds error chunking an empty string (#110033)

An empty string tokenises to 0 tokens (excluding the special marker tokens), this 
case must be handled when restoring the chunked inputs
David Kyle 1 жил өмнө
parent
commit
e63f0c535d

+ 12 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -62,10 +62,18 @@ public class TextEmbeddingProcessor extends NlpTask.Processor {
         if (chunkResults) {
         if (chunkResults) {
             var embeddings = new ArrayList<MlChunkedTextEmbeddingFloatResults.EmbeddingChunk>();
             var embeddings = new ArrayList<MlChunkedTextEmbeddingFloatResults.EmbeddingChunk>();
             for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) {
             for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) {
-                int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset();
-                int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1;
-                int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset();
-                String matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset);
+                String matchedText;
+                if (tokenization.getTokenization(i).tokens().get(0).isEmpty() == false) {
+                    int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset();
+                    int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1;
+                    int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset();
+                    matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset);
+
+                } else {
+                    // No tokens in the input, this should only happen with an empty string
+                    assert tokenization.getTokenization(i).input().get(0).isEmpty();
+                    matchedText = "";
+                }
 
 
                 embeddings.add(
                 embeddings.add(
                     new MlChunkedTextEmbeddingFloatResults.EmbeddingChunk(matchedText, pyTorchResult.getInferenceResult()[0][i])
                     new MlChunkedTextEmbeddingFloatResults.EmbeddingChunk(matchedText, pyTorchResult.getInferenceResult()[0][i])

+ 11 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java

@@ -75,10 +75,17 @@ public class TextExpansionProcessor extends NlpTask.Processor {
             var chunkedResults = new ArrayList<MlChunkedTextExpansionResults.ChunkedResult>();
             var chunkedResults = new ArrayList<MlChunkedTextExpansionResults.ChunkedResult>();
 
 
             for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) {
             for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) {
-                int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset();
-                int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1;
-                int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset();
-                String matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset);
+                String matchedText;
+                if (tokenization.getTokenization(i).tokens().get(0).isEmpty() == false) {
+                    int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset();
+                    int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1;
+                    int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset();
+                    matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset);
+                } else {
+                // No tokens in the input, this should only happen with an empty string
+                    assert tokenization.getTokenization(i).input().get(0).isEmpty();
+                    matchedText = "";
+                }
 
 
                 var weightedTokens = sparseVectorToTokenWeights(pyTorchResult.getInferenceResult()[0][i], tokenization, replacementVocab);
                 var weightedTokens = sparseVectorToTokenWeights(pyTorchResult.getInferenceResult()[0][i], tokenization, replacementVocab);
                 weightedTokens.sort((t1, t2) -> Float.compare(t2.weight(), t1.weight()));
                 weightedTokens.sort((t1, t2) -> Float.compare(t2.weight(), t1.weight()));

+ 27 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.inference.nlp;
 
 
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
@@ -16,9 +17,13 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResul
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 
+import java.util.Map;
+
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.core.IsNot.not;
 
 
 public class TextEmbeddingProcessorTests extends ESTestCase {
 public class TextEmbeddingProcessorTests extends ESTestCase {
 
 
@@ -67,4 +72,26 @@ public class TextEmbeddingProcessorTests extends ESTestCase {
             assertThat(chunkedResult.getChunks().get(1).embedding().length, greaterThan(0));
             assertThat(chunkedResult.getChunks().get(1).embedding().length, greaterThan(0));
         }
         }
     }
     }
+
+    // Verifies the empty-string chunking fix in TextEmbeddingProcessor: an empty
+    // input tokenises to 0 tokens and must still produce a single empty chunk.
+    // (The originally committed test invoked TextExpansionProcessor here — a
+    // copy-paste from TextExpansionProcessorTests — so it never exercised the
+    // embedding processor's fix.)
+    public void testChunkingWithEmptyString() {
+        try (
+            BertTokenizer tokenizer = BertTokenizer.builder(
+                TextExpansionProcessorTests.TEST_CASED_VOCAB,
+                new BertTokenization(null, false, 5, Tokenization.Truncate.NONE, 0)
+            ).build()
+        ) {
+            var pytorchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0, 2.0, 3.0, 4.0, 5.0 } } });
+
+            var input = "";
+            var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null);
+            var tokenizationResult = new BertTokenizationResult(TextExpansionProcessorTests.TEST_CASED_VOCAB, tokenization, 0);
+            // NOTE(review): TextEmbeddingProcessor.processResult is assumed to take no
+            // replacement-vocab map (unlike TextExpansionProcessor) — confirm signature.
+            var inferenceResult = TextEmbeddingProcessor.processResult(tokenizationResult, pytorchResult, "foo", true);
+            assertThat(inferenceResult, instanceOf(MlChunkedTextEmbeddingFloatResults.class));
+
+            var chunkedResult = (MlChunkedTextEmbeddingFloatResults) inferenceResult;
+            assertThat(chunkedResult.getChunks(), hasSize(1));
+            assertEquals("", chunkedResult.getChunks().get(0).matchedText());
+            assertThat(chunkedResult.getChunks().get(0).embedding().length, greaterThan(0));
+        }
+    }
 }
 }

+ 22 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java

@@ -147,4 +147,26 @@ public class TextExpansionProcessorTests extends ESTestCase {
             assertThat(chunkedResult.getChunks().get(1).weightedTokens(), not(empty()));
             assertThat(chunkedResult.getChunks().get(1).weightedTokens(), not(empty()));
         }
         }
     }
     }
+
+    // An empty input tokenises to zero (non-special) tokens; chunked text
+    // expansion must still yield exactly one chunk with empty matched text
+    // instead of throwing an out-of-bounds error.
+    public void testChunkingWithEmptyString() {
+        try (
+            BertTokenizer bertTokenizer = BertTokenizer.builder(
+                TEST_CASED_VOCAB,
+                new BertTokenization(null, false, 5, Tokenization.Truncate.NONE, 0)
+            ).build()
+        ) {
+            var inference = new PyTorchInferenceResult(new double[][][] { { { 1.0, 2.0, 3.0, 4.0, 5.0 } } });
+
+            var emptyInput = "";
+            var tokens = bertTokenizer.tokenize(emptyInput, Tokenization.Truncate.NONE, 0, 0, null);
+            var tokenResult = new BertTokenizationResult(TEST_CASED_VOCAB, tokens, 0);
+            var result = TextExpansionProcessor.processResult(tokenResult, inference, Map.of(), "foo", true);
+            assertThat(result, instanceOf(MlChunkedTextExpansionResults.class));
+
+            var chunked = (MlChunkedTextExpansionResults) result;
+            assertThat(chunked.getChunks(), hasSize(1));
+            assertEquals("", chunked.getChunks().get(0).matchedText());
+            assertThat(chunked.getChunks().get(0).weightedTokens(), not(empty()));
+        }
+    }
 }
 }