Browse Source

[ML] Fix timeout ingesting an empty string into a semantic_text field (#117840) (#118540)

David Kyle 10 months ago
parent
commit
90305d2c9c

+ 5 - 0
docs/changelog/117840.yaml

@@ -0,0 +1,5 @@
+pr: 117840
+summary: Fix timeout ingesting an empty string into a `semantic_text` field
+area: Machine Learning
+type: bug
+issues: []

+ 7 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java

@@ -62,7 +62,8 @@ public class SentenceBoundaryChunker implements Chunker {
      *
      * @param input Text to chunk
      * @param maxNumberWordsPerChunk Maximum size of the chunk
-     * @return The input text chunked
+     * @param includePrecedingSentence Include the previous sentence
+     * @return The input text offsets
      */
     public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
         var chunks = new ArrayList<ChunkOffset>();
@@ -158,6 +159,11 @@ public class SentenceBoundaryChunker implements Chunker {
             chunks.add(new ChunkOffset(chunkStart, input.length()));
         }
 
+        if (chunks.isEmpty()) {
+            // The input did not chunk, return the entire input
+            chunks.add(new ChunkOffset(0, input.length()));
+        }
+
         return chunks;
     }
 

+ 0 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java

@@ -96,10 +96,6 @@ public class WordBoundaryChunker implements Chunker {
             throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
         }
 
-        if (input.isEmpty()) {
-            return List.of();
-        }
-
         var chunkPositions = new ArrayList<ChunkPosition>();
 
         // This position in the chunk is where the next overlapping chunk will start

+ 49 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

@@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByte
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.hamcrest.Matchers;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -31,16 +32,62 @@ import static org.hamcrest.Matchers.startsWith;
 
 public class EmbeddingRequestChunkerTests extends ESTestCase {
 
-    public void testEmptyInput() {
+    public void testEmptyInput_WordChunker() {
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
         var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
         assertThat(batches, empty());
     }
 
-    public void testBlankInput() {
+    public void testEmptyInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testWhitespaceInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of("   "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("   "));
+    }
+
+    public void testBlankInput_WordChunker() {
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
         var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
         assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
+    }
+
+    public void testBlankInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
+    }
+
+    public void testInputThatDoesNotChunk_WordChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
+            testListener()
+        );
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
+    }
+
+    public void testInputThatDoesNotChunk_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
     }
 
     public void testShortInputsAreSingleBatch() {

+ 35 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java

@@ -43,6 +43,41 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList());
     }
 
+    public void testEmptyString() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), "", 100, randomBoolean());
+        assertThat(chunks, hasSize(1));
+        assertThat(chunks.get(0), Matchers.is(""));
+    }
+
+    public void testBlankString() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), "   ", 100, randomBoolean());
+        assertThat(chunks, hasSize(1));
+        assertThat(chunks.get(0), Matchers.is("   "));
+    }
+
+    public void testSingleChar() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), "   b", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains("   b"));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), "b", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains("b"));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), ". ", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(". "));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), " , ", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(" , "));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), " ,", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(" ,"));
+    }
+
+    public void testSingleCharRepeated() {
+        var input = "a".repeat(32_000);
+        var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(input));
+    }
+
     public void testChunkSplitLargeChunkSizes() {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();

+ 30 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java

@@ -11,6 +11,7 @@ import com.ibm.icu.text.BreakIterator;
 
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.test.ESTestCase;
+import org.hamcrest.Matchers;
 
 import java.util.List;
 import java.util.Locale;
@@ -71,10 +72,6 @@ public class WordBoundaryChunkerTests extends ESTestCase {
      * Use the chunk functions that return offsets where possible
      */
     List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
-        if (input.isEmpty()) {
-            return List.of("");
-        }
-
         var chunkPositions = chunker.chunk(input, chunkSize, overlap);
         return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList());
     }
@@ -240,6 +237,35 @@ public class WordBoundaryChunkerTests extends ESTestCase {
         assertThat(chunks, contains(" "));
     }
 
+    public void testBlankString() {
+        var chunks = textChunks(new WordBoundaryChunker(), "   ", 100, 10);
+        assertThat(chunks, hasSize(1));
+        assertThat(chunks.get(0), Matchers.is("   "));
+    }
+
+    public void testSingleChar() {
+        var chunks = textChunks(new WordBoundaryChunker(), "   b", 100, 10);
+        assertThat(chunks, Matchers.contains("   b"));
+
+        chunks = textChunks(new WordBoundaryChunker(), "b", 100, 10);
+        assertThat(chunks, Matchers.contains("b"));
+
+        chunks = textChunks(new WordBoundaryChunker(), ". ", 100, 10);
+        assertThat(chunks, Matchers.contains(". "));
+
+        chunks = textChunks(new WordBoundaryChunker(), " , ", 100, 10);
+        assertThat(chunks, Matchers.contains(" , "));
+
+        chunks = textChunks(new WordBoundaryChunker(), " ,", 100, 10);
+        assertThat(chunks, Matchers.contains(" ,"));
+    }
+
+    public void testSingleCharRepeated() {
+        var input = "a".repeat(32_000);
+        var chunks = textChunks(new WordBoundaryChunker(), input, 100, 10);
+        assertThat(chunks, Matchers.contains(input));
+    }
+
     public void testPunctuation() {
         int chunkSize = 1;
         var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0);

+ 16 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -1142,6 +1142,22 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         }
     }
 
+    public void testInferEmptyInput() throws IOException {
+        String modelId = "empty_input";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
+        request.setJsonEntity("""
+            {  "docs": [] }
+            """);
+
+        var inferenceResponse = client().performRequest(request);
+        assertThat(EntityUtils.toString(inferenceResponse.getEntity()), equalTo("{\"inference_results\":[]}"));
+    }
+
     private void putModelDefinition(String modelId) throws IOException {
         putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
     }

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -132,6 +132,11 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         Response.Builder responseBuilder = Response.builder();
         TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
 
+        if (request.numberOfDocuments() == 0) {
+            listener.onResponse(responseBuilder.setId(request.getId()).build());
+            return;
+        }
+
         if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
             responseBuilder.setLicensed(true);
             doInfer(task, request, responseBuilder, parentTaskId, listener);