[ML] Refactor the Chunker classes to return offsets (#117977) (#118279)

David Kyle, 10 months ago
commit 88a724a293

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java (+3 -1)

@@ -12,5 +12,7 @@ import org.elasticsearch.inference.ChunkingSettings;
 import java.util.List;
 
 public interface Chunker {
-    List<String> chunk(String input, ChunkingSettings chunkingSettings);
+    record ChunkOffset(int start, int end) {};
+
+    List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings);
 }
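
The interface now returns positions into the original input rather than copied substrings, so chunk text is materialized only when a caller asks for it. A minimal caller-side sketch of recovering the text (this helper is illustrative, not part of the commit):

    import java.util.List;
    import java.util.stream.Collectors;

    // Materialize chunk strings from the offsets a Chunker returns.
    static List<String> toText(String input, List<Chunker.ChunkOffset> offsets) {
        return offsets.stream()
            .map(o -> input.substring(o.start(), o.end()))
            .collect(Collectors.toList());
    }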

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java (+35 -20)

@@ -68,7 +68,7 @@ public class EmbeddingRequestChunker {
     private final EmbeddingType embeddingType;
     private final ChunkingSettings chunkingSettings;
 
-    private List<List<String>> chunkedInputs;
+    private List<ChunkOffsetsAndInput> chunkedOffsets;
     private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
     private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
     private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
@@ -109,7 +109,7 @@ public class EmbeddingRequestChunker {
     }
 
     private void splitIntoBatchedRequests(List<String> inputs) {
-        Function<String, List<String>> chunkFunction;
+        Function<String, List<Chunker.ChunkOffset>> chunkFunction;
         if (chunkingSettings != null) {
             var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
             chunkFunction = input -> chunker.chunk(input, chunkingSettings);
@@ -118,7 +118,7 @@ public class EmbeddingRequestChunker {
             chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap);
         }
 
-        chunkedInputs = new ArrayList<>(inputs.size());
+        chunkedOffsets = new ArrayList<>(inputs.size());
         switch (embeddingType) {
             case FLOAT -> floatResults = new ArrayList<>(inputs.size());
             case BYTE -> byteResults = new ArrayList<>(inputs.size());
@@ -128,18 +128,19 @@ public class EmbeddingRequestChunker {
 
         for (int i = 0; i < inputs.size(); i++) {
             var chunks = chunkFunction.apply(inputs.get(i));
-            int numberOfSubBatches = addToBatches(chunks, i);
+            var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i));
+            int numberOfSubBatches = addToBatches(offSetsAndInput, i);
             // size the results array with the expected number of request/responses
             switch (embeddingType) {
                 case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
                 case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
                 case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
             }
-            chunkedInputs.add(chunks);
+            chunkedOffsets.add(offSetsAndInput);
         }
     }
 
-    private int addToBatches(List<String> chunks, int inputIndex) {
+    private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) {
         BatchRequest lastBatch;
         if (batchedRequests.isEmpty()) {
             lastBatch = new BatchRequest(new ArrayList<>());
@@ -157,16 +158,24 @@ public class EmbeddingRequestChunker {
 
         if (freeSpace > 0) {
             // use any free space in the previous batch before creating new batches
-            int toAdd = Math.min(freeSpace, chunks.size());
-            lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
+            int toAdd = Math.min(freeSpace, chunk.offsets().size());
+            lastBatch.addSubBatch(
+                new SubBatch(
+                    new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()),
+                    new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
+                )
+            );
         }
 
         int start = freeSpace;
-        while (start < chunks.size()) {
-            int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start);
+        while (start < chunk.offsets().size()) {
+            int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start);
             var batch = new BatchRequest(new ArrayList<>());
             batch.addSubBatch(
-                new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))
+                new SubBatch(
+                    new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()),
+                    new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
+                )
             );
             batchedRequests.add(batch);
             start += toAdd;
@@ -333,8 +342,8 @@ public class EmbeddingRequestChunker {
         }
 
         private void sendResponse() {
-            var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedInputs.size());
-            for (int i = 0; i < chunkedInputs.size(); i++) {
+            var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedOffsets.size());
+            for (int i = 0; i < chunkedOffsets.size(); i++) {
                 if (errors.get(i) != null) {
                     response.add(errors.get(i));
                 } else {
@@ -348,9 +357,9 @@ public class EmbeddingRequestChunker {
 
     private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
         return switch (embeddingType) {
-            case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex));
-            case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex));
-            case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex));
+            case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex));
+            case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex));
+            case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex));
         };
     }
 
@@ -428,7 +437,7 @@ public class EmbeddingRequestChunker {
         }
 
         public List<String> inputs() {
-            return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
+            return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList());
         }
     }
 
@@ -441,9 +450,15 @@ public class EmbeddingRequestChunker {
      */
     record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}
 
-    record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
-        public int size() {
-            return requests.size();
+    record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) {
+        int size() {
+            return requests.offsets().size();
+        }
+    }
+
+    record ChunkOffsetsAndInput(List<Chunker.ChunkOffset> offsets, String input) {
+        List<String> toChunkText() {
+            return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList());
         }
     }
 }
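
The new ChunkOffsetsAndInput record keeps each input string alongside its offsets, so sub-batches can be sliced with List.subList without copying any text; substrings are only built when toChunkText() assembles the request payload. A sketch of that laziness (assumes same-package access to the package-private record; `input` and `offsets` are stand-ins for a real input and its chunk offsets):

    // Slicing the offsets shares the backing string; no text is copied yet.
    var offsetsAndInput = new EmbeddingRequestChunker.ChunkOffsetsAndInput(offsets, input);
    var slice = new EmbeddingRequestChunker.ChunkOffsetsAndInput(
        offsetsAndInput.offsets().subList(0, Math.min(2, offsets.size())),
        offsetsAndInput.input()
    );
    List<String> requestText = slice.toChunkText(); // substrings are created only here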

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java (+12 -8)

@@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker {
     public SentenceBoundaryChunker() {
         sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
         wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
-
     }
 
     /**
@@ -45,7 +44,7 @@ public class SentenceBoundaryChunker implements Chunker {
      * @return The input text chunked
      */
     @Override
-    public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
+    public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
         if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) {
             return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0);
         } else {
@@ -65,8 +64,8 @@ public class SentenceBoundaryChunker implements Chunker {
      * @param maxNumberWordsPerChunk Maximum size of the chunk
      * @return The input text chunked
      */
-    public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
-        var chunks = new ArrayList<String>();
+    public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
+        var chunks = new ArrayList<ChunkOffset>();
 
         sentenceIterator.setText(input);
         wordIterator.setText(input);
@@ -91,7 +90,7 @@ public class SentenceBoundaryChunker implements Chunker {
                 int nextChunkWordCount = wordsInSentenceCount;
                 if (chunkWordCount > 0) {
                     // add a new chunk containing all the input up to this sentence
-                    chunks.add(input.substring(chunkStart, chunkEnd));
+                    chunks.add(new ChunkOffset(chunkStart, chunkEnd));
 
                     if (includePrecedingSentence) {
                         if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
@@ -127,12 +126,17 @@ public class SentenceBoundaryChunker implements Chunker {
                     for (; i < sentenceSplits.size() - 1; i++) {
                         // Because the substring was passed to splitLongSentence()
                         // the returned positions need to be offset by chunkStart
-                        chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end()));
+                        chunks.add(
+                            new ChunkOffset(
+                                chunkStart + sentenceSplits.get(i).offsets().start(),
+                                chunkStart + sentenceSplits.get(i).offsets().end()
+                            )
+                        );
                     }
                     // The final split is partially filled.
                     // Set the next chunk start to the beginning of the
                     // final split of the long sentence.
-                    chunkStart = chunkStart + sentenceSplits.get(i).start();  // start pos needs to be offset by chunkStart
+                    chunkStart = chunkStart + sentenceSplits.get(i).offsets().start();  // start pos needs to be offset by chunkStart
                     chunkWordCount = sentenceSplits.get(i).wordCount();
                 }
             } else {
@@ -151,7 +155,7 @@ public class SentenceBoundaryChunker implements Chunker {
         }
 
         if (chunkWordCount > 0) {
-            chunks.add(input.substring(chunkStart));
+            chunks.add(new ChunkOffset(chunkStart, input.length()));
         }
 
         return chunks;
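
With the return type changed, calling code follows the same shape as before, except the caller takes the substrings. A short sketch of the refactored call path (`text` and the chunk size are illustrative):

    var chunker = new SentenceBoundaryChunker();
    List<Chunker.ChunkOffset> offsets = chunker.chunk(text, 100, false);
    List<String> chunks = offsets.stream()
        .map(o -> text.substring(o.start(), o.end())) // offsets index into `text`
        .collect(Collectors.toList());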

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java (+7 -15)

@@ -15,6 +15,7 @@ import org.elasticsearch.inference.ChunkingSettings;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
 
 /**
  * Breaks text into smaller strings or chunks on Word boundaries.
@@ -35,7 +36,7 @@ public class WordBoundaryChunker implements Chunker {
         wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
     }
 
-    record ChunkPosition(int start, int end, int wordCount) {}
+    record ChunkPosition(ChunkOffset offsets, int wordCount) {}
 
     /**
      * Break the input text into small chunks as dictated
@@ -45,7 +46,7 @@ public class WordBoundaryChunker implements Chunker {
      * @return List of chunked text
      */
     @Override
-    public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
+    public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
         if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) {
             return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap);
         } else {
@@ -64,18 +65,9 @@ public class WordBoundaryChunker implements Chunker {
      *                Can be 0 but must be non-negative.
      * @return List of chunked text
      */
-    public List<String> chunk(String input, int chunkSize, int overlap) {
-
-        if (input.isEmpty()) {
-            return List.of("");
-        }
-
+    public List<ChunkOffset> chunk(String input, int chunkSize, int overlap) {
         var chunkPositions = chunkPositions(input, chunkSize, overlap);
-        var chunks = new ArrayList<String>(chunkPositions.size());
-        for (var pos : chunkPositions) {
-            chunks.add(input.substring(pos.start, pos.end));
-        }
-        return chunks;
+        return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList());
     }
 
     /**
@@ -127,7 +119,7 @@ public class WordBoundaryChunker implements Chunker {
                 wordsSinceStartWindowWasMarked++;
 
                 if (wordsInChunkCountIncludingOverlap >= chunkSize) {
-                    chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap));
+                    chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap));
                     wordsInChunkCountIncludingOverlap = overlap;
 
                     if (overlap == 0) {
@@ -149,7 +141,7 @@ public class WordBoundaryChunker implements Chunker {
         // if it ends on a boundary then the count should equal overlap, in which case
         // we can ignore it, unless this is the first chunk in which case we want to add it
         if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) {
-            chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap));
+            chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap));
         }
 
         return chunkPositions;
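
Note the deleted `input.isEmpty()` special case: `chunk()` no longer short-circuits to a single empty string, so callers that relied on that behaviour must reproduce it themselves, as the `textChunks` helper in the updated WordBoundaryChunkerTests below does. A caller-side sketch of the old behaviour:

    // Equivalent of the removed special case, done by the caller (sketch):
    List<String> chunks = input.isEmpty()
        ? List.of("")
        : chunker.chunk(input, chunkSize, overlap)
            .stream()
            .map(o -> input.substring(o.start(), o.end()))
            .collect(Collectors.toList());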

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java (+12 -12)

@@ -62,7 +62,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
         var subBatches = batches.get(0).batch().subBatches();
         for (int i = 0; i < inputs.size(); i++) {
             var subBatch = subBatches.get(i);
-            assertThat(subBatch.requests(), contains(inputs.get(i)));
+            assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
             assertEquals(0, subBatch.positions().chunkIndex());
             assertEquals(i, subBatch.positions().inputIndex());
             assertEquals(1, subBatch.positions().embeddingCount());
@@ -102,7 +102,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
         var subBatches = batches.get(0).batch().subBatches();
         for (int i = 0; i < batches.size(); i++) {
             var subBatch = subBatches.get(i);
-            assertThat(subBatch.requests(), contains(inputs.get(i)));
+            assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
             assertEquals(0, subBatch.positions().chunkIndex());
             assertEquals(inputIndex, subBatch.positions().inputIndex());
             assertEquals(1, subBatch.positions().embeddingCount());
@@ -146,7 +146,7 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
         var subBatches = batches.get(0).batch().subBatches();
         for (int i = 0; i < batches.size(); i++) {
             var subBatch = subBatches.get(i);
-            assertThat(subBatch.requests(), contains(inputs.get(i)));
+            assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
             assertEquals(0, subBatch.positions().chunkIndex());
             assertEquals(inputIndex, subBatch.positions().inputIndex());
             assertEquals(1, subBatch.positions().embeddingCount());
@@ -184,17 +184,17 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
                 assertEquals(0, subBatch.positions().inputIndex());
                 assertEquals(0, subBatch.positions().chunkIndex());
                 assertEquals(1, subBatch.positions().embeddingCount());
-                assertThat(subBatch.requests(), contains("1st small"));
+                assertThat(subBatch.requests().toChunkText(), contains("1st small"));
             }
             {
                 var subBatch = batch.subBatches().get(1);
                 assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
                 assertEquals(0, subBatch.positions().chunkIndex());  // 1st part of the 2nd input
                 assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks
-                assertThat(subBatch.requests().get(0), startsWith("passage_input0 "));
-                assertThat(subBatch.requests().get(1), startsWith(" passage_input20 "));
-                assertThat(subBatch.requests().get(2), startsWith(" passage_input40 "));
-                assertThat(subBatch.requests().get(3), startsWith(" passage_input60 "));
+                assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 "));
+                assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 "));
+                assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 "));
+                assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 "));
             }
         }
         {
@@ -207,22 +207,22 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
                 assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
                 assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input
                 assertEquals(2, subBatch.positions().embeddingCount());
-                assertThat(subBatch.requests().get(0), startsWith(" passage_input80 "));
-                assertThat(subBatch.requests().get(1), startsWith(" passage_input100 "));
+                assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 "));
+                assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 "));
             }
             {
                 var subBatch = batch.subBatches().get(1);
                 assertEquals(2, subBatch.positions().inputIndex()); // 3rd input
                 assertEquals(0, subBatch.positions().chunkIndex());  // 1st and only part
                 assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
-                assertThat(subBatch.requests(), contains("2nd small"));
+                assertThat(subBatch.requests().toChunkText(), contains("2nd small"));
             }
             {
                 var subBatch = batch.subBatches().get(2);
                 assertEquals(3, subBatch.positions().inputIndex());  // 4th input
                 assertEquals(0, subBatch.positions().chunkIndex());  // 1st and only part
                 assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
-                assertThat(subBatch.requests(), contains("3rd small"));
+                assertThat(subBatch.requests().toChunkText(), contains("3rd small"));
             }
         }
     }

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java (+25 -9)

@@ -15,7 +15,9 @@ import org.hamcrest.Matchers;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
 import static org.hamcrest.Matchers.containsString;
@@ -27,10 +29,24 @@ import static org.hamcrest.Matchers.startsWith;
 
 public class SentenceBoundaryChunkerTests extends ESTestCase {
 
+    /**
+     * Utility method for testing.
+     * Use the chunk functions that return offsets where possible
+     */
+    private List<String> textChunks(
+        SentenceBoundaryChunker chunker,
+        String input,
+        int maxNumberWordsPerChunk,
+        boolean includePrecedingSentence
+    ) {
+        var chunkPositions = chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence);
+        return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList());
+    }
+
     public void testChunkSplitLargeChunkSizes() {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
+            var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false);
 
             int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk);
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks));
@@ -48,7 +64,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         boolean overlap = true;
         for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) {
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap);
+            var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, overlap);
 
             int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap);
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length));
@@ -107,7 +123,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         }
 
         var chunker = new SentenceBoundaryChunker();
-        var chunks = chunker.chunk(sb.toString(), chunkSize, true);
+        var chunks = textChunks(chunker, sb.toString(), chunkSize, true);
         assertThat(chunks, hasSize(numChunks));
         for (int i = 0; i < numChunks; i++) {
             assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i]));
@@ -128,10 +144,10 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
     public void testChunk_ChunkSizeLargerThanText() {
         int maxWordsPerChunk = 500;
         var chunker = new SentenceBoundaryChunker();
-        var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
+        var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false);
         assertEquals(chunks.get(0), TEST_TEXT);
 
-        chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true);
+        chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true);
         assertEquals(chunks.get(0), TEST_TEXT);
     }
 
@@ -142,7 +158,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         for (int i = 0; i < chunkSizes.length; i++) {
             int maxWordsPerChunk = chunkSizes[i];
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
+            var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false);
 
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i]));
             for (var chunk : chunks) {
@@ -171,7 +187,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         for (int i = 0; i < chunkSizes.length; i++) {
             int maxWordsPerChunk = chunkSizes[i];
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true);
+            var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true);
             assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing"));
             assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm."));
         }
@@ -190,7 +206,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         }
 
         var chunker = new SentenceBoundaryChunker();
-        var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true);
+        var chunks = textChunks(chunker, sb.toString(), maxWordsPerChunk, true);
         assertThat(chunks, hasSize(5));
         assertTrue(chunks.get(0).trim().startsWith("SStart0"));  // Entire sentence
         assertTrue(chunks.get(0).trim().endsWith("."));  // Entire sentence
@@ -303,7 +319,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();
             SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0);
-            var chunks = chunker.chunk(TEST_TEXT, chunkingSettings);
+            var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false);
 
             int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk);
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java (+25 -11)

@@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase;
 
 import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
@@ -65,9 +66,22 @@ public class WordBoundaryChunkerTests extends ESTestCase {
         NUM_WORDS_IN_TEST_TEXT = wordCount;
     }
 
+    /**
+     * Utility method for testing.
+     * Use the chunk functions that return offsets where possible
+     */
+    List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
+        if (input.isEmpty()) {
+            return List.of("");
+        }
+
+        var chunkPositions = chunker.chunk(input, chunkSize, overlap);
+        return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList());
+    }
+
     public void testSingleSplit() {
         var chunker = new WordBoundaryChunker();
-        var chunks = chunker.chunk(TEST_TEXT, 10_000, 0);
+        var chunks = textChunks(chunker, TEST_TEXT, 10_000, 0);
         assertThat(chunks, hasSize(1));
         assertEquals(TEST_TEXT, chunks.get(0));
     }
@@ -168,11 +182,11 @@ public class WordBoundaryChunkerTests extends ESTestCase {
         }
         var whiteSpacedText = input.toString().stripTrailing();
 
-        var chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 20, 10);
+        var chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 20, 10);
         assertChunkContents(chunks, numWords, 20, 10);
-        chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 10, 4);
+        chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 10, 4);
         assertChunkContents(chunks, numWords, 10, 4);
-        chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 15, 3);
+        chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 15, 3);
         assertChunkContents(chunks, numWords, 15, 3);
     }
 
@@ -217,28 +231,28 @@ public class WordBoundaryChunkerTests extends ESTestCase {
     }
 
     public void testEmptyString() {
-        var chunks = new WordBoundaryChunker().chunk("", 10, 5);
-        assertThat(chunks, contains(""));
+        var chunks = textChunks(new WordBoundaryChunker(), "", 10, 5);
+        assertThat(chunks.toString(), chunks, contains(""));
     }
 
     public void testWhitespace() {
-        var chunks = new WordBoundaryChunker().chunk(" ", 10, 5);
+        var chunks = textChunks(new WordBoundaryChunker(), " ", 10, 5);
         assertThat(chunks, contains(" "));
     }
 
     public void testPunctuation() {
         int chunkSize = 1;
-        var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
+        var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0);
         assertThat(chunks, contains("Comma", ", separated"));
 
-        chunks = new WordBoundaryChunker().chunk("Mme. Thénardier", chunkSize, 0);
+        chunks = textChunks(new WordBoundaryChunker(), "Mme. Thénardier", chunkSize, 0);
         assertThat(chunks, contains("Mme", ". Thénardier"));
 
-        chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0);
+        chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0);
         assertThat(chunks, contains("Won't", " you", " chunk"));
 
         chunkSize = 10;
-        chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0);
+        chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0);
         assertThat(chunks, contains("Won't you chunk"));
     }