
[ML] Add sentence overlap option to the sentence chunking settings (#114461) (#114626)

David Kyle 1 year ago
parent
commit
58b835b9da

+ 1 - 1
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -240,7 +240,7 @@ public class TransportVersions {
     public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0);
     public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0);
     public static final TransportVersion ESQL_CACHED_STRING_SERIALIZATION = def(8_766_00_0);
-    public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDE1D = def(8_767_00_0);
+    public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
     public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
 
     /*

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java

@@ -10,7 +10,8 @@ package org.elasticsearch.xpack.inference.chunking;
 public enum ChunkingSettingsOptions {
     STRATEGY("strategy"),
     MAX_CHUNK_SIZE("max_chunk_size"),
-    OVERLAP("overlap");
+    OVERLAP("overlap"),
+    SENTENCE_OVERLAP("sentence_overlap");
 
     private final String chunkingSettingsOption;
 

+ 77 - 10
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java

@@ -34,6 +34,7 @@ public class SentenceBoundaryChunker implements Chunker {
     public SentenceBoundaryChunker() {
         sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
         wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
+
     }
 
     /**
@@ -46,7 +47,7 @@ public class SentenceBoundaryChunker implements Chunker {
     @Override
     public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
         if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) {
-            return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize);
+            return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0);
         } else {
             throw new IllegalArgumentException(
                 Strings.format(
@@ -64,7 +65,7 @@ public class SentenceBoundaryChunker implements Chunker {
      * @param maxNumberWordsPerChunk Maximum size of the chunk
      * @return The input text chunked
      */
-    public List<String> chunk(String input, int maxNumberWordsPerChunk) {
+    public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
         var chunks = new ArrayList<String>();
 
         sentenceIterator.setText(input);
@@ -75,24 +76,46 @@ public class SentenceBoundaryChunker implements Chunker {
         int sentenceStart = 0;
         int chunkWordCount = 0;
 
+        int wordsInPrecedingSentenceCount = 0;
+        int previousSentenceStart = 0;
+
         int boundary = sentenceIterator.next();
 
         while (boundary != BreakIterator.DONE) {
             int sentenceEnd = sentenceIterator.current();
-            int countWordsInSentence = countWords(sentenceStart, sentenceEnd);
+            int wordsInSentenceCount = countWords(sentenceStart, sentenceEnd);
 
-            if (chunkWordCount + countWordsInSentence > maxNumberWordsPerChunk) {
+            if (chunkWordCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
                 // over the max chunk size, roll back to the last sentence
 
+                int nextChunkWordCount = wordsInSentenceCount;
                 if (chunkWordCount > 0) {
                     // add a new chunk containing all the input up to this sentence
                     chunks.add(input.substring(chunkStart, chunkEnd));
-                    chunkStart = chunkEnd;
-                    chunkWordCount = countWordsInSentence; // the next chunk will contain this sentence
+
+                    if (includePrecedingSentence) {
+                        if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
+                            // cut the last sentence
+                            int numWordsToSkip = numWordsToSkipInPreviousSentence(wordsInPrecedingSentenceCount, maxNumberWordsPerChunk);
+
+                            chunkStart = skipWords(input, previousSentenceStart, numWordsToSkip);
+                            chunkWordCount = (wordsInPrecedingSentenceCount - numWordsToSkip) + wordsInSentenceCount;
+                        } else {
+                            chunkWordCount = wordsInPrecedingSentenceCount + wordsInSentenceCount;
+                            chunkStart = previousSentenceStart;
+                        }
+
+                        nextChunkWordCount = chunkWordCount;
+                    } else {
+                        chunkStart = chunkEnd;
+                        chunkWordCount = wordsInSentenceCount; // the next chunk will contain this sentence
+                    }
                 }
 
-                if (countWordsInSentence > maxNumberWordsPerChunk) {
-                    // This sentence is bigger than the max chunk size.
+                // Is the next chunk larger than max chunk size?
+                // If so split it
+                if (nextChunkWordCount > maxNumberWordsPerChunk) {
+                    // This sentence (and optional overlap) is bigger than the max chunk size.
                     // Split the sentence on the word boundary
                     var sentenceSplits = splitLongSentence(
                         input.substring(chunkStart, sentenceEnd),
@@ -113,7 +136,12 @@ public class SentenceBoundaryChunker implements Chunker {
                     chunkWordCount = sentenceSplits.get(i).wordCount();
                 }
             } else {
-                chunkWordCount += countWordsInSentence;
+                chunkWordCount += wordsInSentenceCount;
+            }
+
+            if (includePrecedingSentence) {
+                previousSentenceStart = sentenceStart;
+                wordsInPrecedingSentenceCount = wordsInSentenceCount;
             }
 
             sentenceStart = sentenceEnd;
@@ -133,6 +161,45 @@ public class SentenceBoundaryChunker implements Chunker {
         return new WordBoundaryChunker().chunkPositions(text, maxNumberOfWords, overlap);
     }
 
+    static int numWordsToSkipInPreviousSentence(int wordsInPrecedingSentenceCount, int maxNumberWordsPerChunk) {
+        var maxWordsInOverlap = maxWordsInOverlap(maxNumberWordsPerChunk);
+        if (wordsInPrecedingSentenceCount > maxWordsInOverlap) {
+            return wordsInPrecedingSentenceCount - maxWordsInOverlap;
+        } else {
+            return 0;
+        }
+    }
+
+    static int maxWordsInOverlap(int maxNumberWordsPerChunk) {
+        return Math.min(maxNumberWordsPerChunk / 2, 20);
+    }
+
+    private int skipWords(String input, int start, int numWords) {
+        var itr = BreakIterator.getWordInstance(Locale.ROOT);
+        itr.setText(input);
+        return skipWords(start, numWords, itr);
+    }
+
+    static int skipWords(int start, int numWords, BreakIterator wordIterator) {
+        wordIterator.preceding(start); // start of the current word
+
+        int boundary = wordIterator.current();
+        int wordCount = 0;
+        while (boundary != BreakIterator.DONE && wordCount < numWords) {
+            int wordStatus = wordIterator.getRuleStatus();
+            if (wordStatus != BreakIterator.WORD_NONE) {
+                wordCount++;
+            }
+            boundary = wordIterator.next();
+        }
+
+        if (boundary == BreakIterator.DONE) {
+            return wordIterator.last();
+        } else {
+            return boundary;
+        }
+    }
+
     private int countWords(int start, int end) {
         return countWords(start, end, this.wordIterator);
     }
@@ -157,6 +224,6 @@ public class SentenceBoundaryChunker implements Chunker {
     }
 
     private static int overlapForChunkSize(int chunkSize) {
-        return (chunkSize - 1) / 2;
+        return Math.min(20, (chunkSize - 1) / 2);
     }
 }
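
For orientation, a minimal usage sketch of the new three-argument chunk method added above; the input text and chunk size are placeholders, not values taken from the change:

    // Hedged sketch: SentenceBoundaryChunker.chunk(String, int, boolean) as introduced in this diff.
    // With includePrecedingSentence == true, each chunk after the first is seeded with up to
    // maxWordsInOverlap(maxWords) == Math.min(maxWords / 2, 20) words from the end of the previous sentence.
    var chunker = new SentenceBoundaryChunker();
    List<String> withOverlap = chunker.chunk(inputText, 100, true);     // inputText is a placeholder
    List<String> withoutOverlap = chunker.chunk(inputText, 100, false); // previous behaviour, no overlap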

+ 30 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.ChunkingStrategy;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -30,16 +31,25 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
     private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE;
     private static final Set<String> VALID_KEYS = Set.of(
         ChunkingSettingsOptions.STRATEGY.toString(),
-        ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString()
+        ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
+        ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()
     );
+
+    private static int DEFAULT_OVERLAP = 0;
+
     protected final int maxChunkSize;
+    protected int sentenceOverlap = DEFAULT_OVERLAP;
 
-    public SentenceBoundaryChunkingSettings(Integer maxChunkSize) {
+    public SentenceBoundaryChunkingSettings(Integer maxChunkSize, @Nullable Integer sentenceOverlap) {
         this.maxChunkSize = maxChunkSize;
+        this.sentenceOverlap = sentenceOverlap == null ? DEFAULT_OVERLAP : sentenceOverlap;
     }
 
     public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException {
         maxChunkSize = in.readInt();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) {
+            sentenceOverlap = in.readVInt();
+        }
     }
 
     public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map) {
@@ -59,11 +69,24 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
             validationException
         );
 
+        Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger(
+            map,
+            ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
+            ModelConfigurations.CHUNKING_SETTINGS,
+            validationException
+        );
+
+        if (sentenceOverlap != null && sentenceOverlap > 1) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1"
+            ); // todo better
+        }
+
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }
 
-        return new SentenceBoundaryChunkingSettings(maxChunkSize);
+        return new SentenceBoundaryChunkingSettings(maxChunkSize, sentenceOverlap);
     }
 
     @Override
@@ -72,6 +95,7 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
         {
             builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY);
             builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
+            builder.field(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), sentenceOverlap);
         }
         builder.endObject();
         return builder;
@@ -90,6 +114,9 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeInt(maxChunkSize);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) {
+            out.writeVInt(sentenceOverlap);
+        }
     }
 
     @Override
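
A hedged end-to-end sketch of the new setting; the map keys come from ChunkingSettingsOptions above, while the literal "sentence" strategy value and the mutable map are assumptions of this example:

    // sentence_overlap must be 0 or 1, per the validation added in fromMap above.
    Map<String, Object> map = new HashMap<>(
        Map.of("strategy", "sentence", "max_chunk_size", 250, "sentence_overlap", 1)
    );
    SentenceBoundaryChunkingSettings settings = SentenceBoundaryChunkingSettings.fromMap(map);

    // Or construct directly; a null sentenceOverlap falls back to the default of 0.
    settings = new SentenceBoundaryChunkingSettings(250, 1);

    // The chunker treats sentenceOverlap > 0 as "include the preceding sentence".
    List<String> chunks = new SentenceBoundaryChunker().chunk(inputText, settings); // inputText is a placeholder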

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java

@@ -52,7 +52,7 @@ public class WordBoundaryChunkingSettings implements ChunkingSettings {
         var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray();
         if (invalidSettings.length > 0) {
             validationException.addValidationError(
-                Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
+                Strings.format("Word based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
             );
         }
 

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java

@@ -56,7 +56,7 @@ public class ChunkingSettingsBuilderTests extends ESTestCase {
                 ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
                 maxChunkSize
             ),
-            new SentenceBoundaryChunkingSettings(maxChunkSize)
+            new SentenceBoundaryChunkingSettings(maxChunkSize, 1)
         );
     }
 }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java

@@ -25,7 +25,7 @@ public class ChunkingSettingsTests extends ESTestCase {
                 return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
             }
             case SENTENCE -> {
-                return new SentenceBoundaryChunkingSettings(randomNonNegativeInt());
+                return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 0 : 1);
             }
             default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]");
         }

+ 175 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java

@@ -13,19 +13,24 @@ import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.test.ESTestCase;
 import org.hamcrest.Matchers;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Locale;
 
 import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.endsWith;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.startsWith;
 
 public class SentenceBoundaryChunkerTests extends ESTestCase {
 
     public void testChunkSplitLargeChunkSizes() {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk);
+            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
 
             int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk);
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks));
@@ -39,11 +44,94 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         }
     }
 
+    public void testChunkSplitLargeChunkSizes_withOverlap() {
+        boolean overlap = true;
+        for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) {
+            var chunker = new SentenceBoundaryChunker();
+            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap);
+
+            int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap);
+            assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length));
+
+            assertTrue(Character.isUpperCase(chunks.get(0).charAt(0)));
+
+            for (int i = 0; i < overlaps.length; i++) {
+                if (overlaps[i] == 0) {
+                    // start of a sentence
+                    assertTrue(Character.isUpperCase(chunks.get(i).charAt(0)));
+                } else {
+                    // The start of this chunk should contain some text from the end of the previous
+                    var previousChunk = chunks.get(i - 1);
+                    assertThat(chunks.get(i), containsString(previousChunk.substring(previousChunk.length() - 20)));
+                }
+            }
+
+            var trailingWhiteSpaceRemoved = chunks.get(0).strip();
+            var lastChar = trailingWhiteSpaceRemoved.charAt(trailingWhiteSpaceRemoved.length() - 1);
+            assertThat(lastChar, Matchers.is('.'));
+            trailingWhiteSpaceRemoved = chunks.get(chunks.size() - 1).strip();
+            lastChar = trailingWhiteSpaceRemoved.charAt(trailingWhiteSpaceRemoved.length() - 1);
+            assertThat(lastChar, Matchers.is('.'));
+        }
+    }
+
+    public void testWithOverlap_SentencesFitInChunks() {
+        int numChunks = 4;
+        int chunkSize = 100;
+
+        var sb = new StringBuilder();
+
+        int[] sentenceStartIndexes = new int[numChunks];
+        sentenceStartIndexes[0] = 0;
+
+        int numSentences = randomIntBetween(2, 5);
+        int sentenceIndex = 0;
+        int lastSentenceSize = 0;
+        int roughSentenceSize = (chunkSize / numSentences) - 1;
+        for (int j = 0; j < numSentences; j++) {
+            sb.append(makeSentence(roughSentenceSize, sentenceIndex++));
+            lastSentenceSize = roughSentenceSize;
+        }
+
+        for (int i = 1; i < numChunks; i++) {
+            sentenceStartIndexes[i] = sentenceIndex - 1;
+
+            roughSentenceSize = (chunkSize / numSentences) - 1;
+            int wordCount = lastSentenceSize;
+
+            while (wordCount + roughSentenceSize < chunkSize) {
+                sb.append(makeSentence(roughSentenceSize, sentenceIndex++));
+                lastSentenceSize = roughSentenceSize;
+                wordCount += roughSentenceSize;
+            }
+        }
+
+        var chunker = new SentenceBoundaryChunker();
+        var chunks = chunker.chunk(sb.toString(), chunkSize, true);
+        assertThat(chunks, hasSize(numChunks));
+        for (int i = 0; i < numChunks; i++) {
+            assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i]));
+            assertThat("num sentences " + numSentences, chunks.get(i).trim(), endsWith("."));
+        }
+    }
+
+    private String makeSentence(int numWords, int sentenceIndex) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("SStart").append(sentenceIndex).append(' ');
+        for (int i = 1; i < numWords - 1; i++) {
+            sb.append(i).append(' ');
+        }
+        sb.append(numWords - 1).append(". ");
+        return sb.toString();
+    }
+
     public void testChunk_ChunkSizeLargerThanText() {
         int maxWordsPerChunk = 500;
         var chunker = new SentenceBoundaryChunker();
-        var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk);
+        var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
+        assertEquals(chunks.get(0), TEST_TEXT);
 
+        chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true);
         assertEquals(chunks.get(0), TEST_TEXT);
     }
 
@@ -54,7 +142,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         for (int i = 0; i < chunkSizes.length; i++) {
             int maxWordsPerChunk = chunkSizes[i];
             var chunker = new SentenceBoundaryChunker();
-            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk);
+            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false);
 
             assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i]));
             for (var chunk : chunks) {
@@ -76,6 +164,48 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         }
     }
 
+    public void testChunkSplit_SentencesLongerThanChunkSize_WithOverlap() {
+        var chunkSizes = new int[] { 10, 30, 50 };
+
+        // Chunk sizes are shorter than the sentences, so most of the sentences will be split.
+        for (int i = 0; i < chunkSizes.length; i++) {
+            int maxWordsPerChunk = chunkSizes[i];
+            var chunker = new SentenceBoundaryChunker();
+            var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true);
+            assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing"));
+            assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm."));
+        }
+    }
+
+    public void testShortLongShortSentences_WithOverlap() {
+        int maxWordsPerChunk = 40;
+        var sb = new StringBuilder();
+        int[] sentenceLengths = new int[] { 15, 30, 20, 5 };
+        for (int l = 0; l < sentenceLengths.length; l++) {
+            sb.append("SStart").append(l).append(" ");
+            for (int i = 1; i < sentenceLengths[l] - 1; i++) {
+                sb.append(i).append(' ');
+            }
+            sb.append(sentenceLengths[l] - 1).append(". ");
+        }
+
+        var chunker = new SentenceBoundaryChunker();
+        var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true);
+        assertThat(chunks, hasSize(5));
+        assertTrue(chunks.get(0).trim().startsWith("SStart0"));  // Entire sentence
+        assertTrue(chunks.get(0).trim().endsWith("."));  // Entire sentence
+
+        assertTrue(chunks.get(1).trim().startsWith("SStart0"));  // contains previous sentence
+        assertFalse(chunks.get(1).trim().endsWith("."));   // not a full sentence(s)
+
+        assertTrue(chunks.get(2).trim().endsWith("."));
+        assertTrue(chunks.get(3).trim().endsWith("."));
+
+        assertTrue(chunks.get(4).trim().startsWith("SStart2"));  // contains previous sentence
+        assertThat(chunks.get(4), containsString("SStart3"));   // last chunk contains 2 sentences
+        assertTrue(chunks.get(4).trim().endsWith("."));   // full sentence(s)
+    }
+
     public void testCountWords() {
         // Test word count matches the whitespace separated word count.
         var splitByWhiteSpaceSentenceSizes = sentenceSizes(TEST_TEXT);
@@ -102,6 +232,30 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
         assertEquals(BreakIterator.DONE, sentenceIterator.next());
     }
 
+    public void testSkipWords() {
+        int numWords = 50;
+        StringBuilder sb = new StringBuilder();
+        for (int i = 0; i < numWords; i++) {
+            sb.append("word").append(i).append(" ");
+        }
+        var text = sb.toString();
+
+        var wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
+        wordIterator.setText(text);
+
+        int start = 0;
+        int pos = SentenceBoundaryChunker.skipWords(start, 3, wordIterator);
+        assertThat(text.substring(pos), startsWith("word3 "));
+        pos = SentenceBoundaryChunker.skipWords(pos + 1, 1, wordIterator);
+        assertThat(text.substring(pos), startsWith("word4 "));
+        pos = SentenceBoundaryChunker.skipWords(pos + 1, 5, wordIterator);
+        assertThat(text.substring(pos), startsWith("word9 "));
+
+        // past the end of the input
+        pos = SentenceBoundaryChunker.skipWords(0, numWords + 10, wordIterator);
+        assertThat(pos, greaterThan(0));
+    }
+
     public void testCountWords_short() {
         // Test word count matches the whitespace separated word count.
         var text = "This is a short sentence. Followed by another.";
@@ -148,7 +302,7 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
     public void testChunkSplitLargeChunkSizesWithChunkingSettings() {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();
-            SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk);
+            SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0);
             var chunks = chunker.chunk(TEST_TEXT, chunkingSettings);
 
             int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk);
@@ -182,16 +336,30 @@ public class SentenceBoundaryChunkerTests extends ESTestCase {
     }
 
     private int expectedNumberOfChunks(int[] sentenceLengths, int maxWordsPerChunk) {
-        int numChunks = 1;
+        return chunkOverlaps(sentenceLengths, maxWordsPerChunk, false).length;
+    }
+
+    private int[] chunkOverlaps(int[] sentenceLengths, int maxWordsPerChunk, boolean includeSingleSentenceOverlap) {
+        int maxOverlap = SentenceBoundaryChunker.maxWordsInOverlap(maxWordsPerChunk);
+
+        var overlaps = new ArrayList<Integer>();
+        overlaps.add(0);
         int runningWordCount = 0;
         for (int i = 0; i < sentenceLengths.length; i++) {
             if (runningWordCount + sentenceLengths[i] > maxWordsPerChunk) {
-                numChunks++;
                 runningWordCount = sentenceLengths[i];
+                if (includeSingleSentenceOverlap && i > 0) {
+                    // include what is carried over from the previous
+                    int overlap = Math.min(maxOverlap, sentenceLengths[i - 1]);
+                    overlaps.add(overlap);
+                    runningWordCount += overlap;
+                } else {
+                    overlaps.add(0);
+                }
             } else {
                 runningWordCount += sentenceLengths[i];
             }
         }
-        return numChunks;
+        return overlaps.stream().mapToInt(Integer::intValue).toArray();
     }
 }
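
As a quick check of the overlap arithmetic these tests rely on, a worked example using the package-private helpers from SentenceBoundaryChunker; the sentence lengths are illustrative:

    // Overlap budget is min(maxWordsPerChunk / 2, 20).
    int budget = SentenceBoundaryChunker.maxWordsInOverlap(40);                     // min(20, 20) == 20
    // A 30-word preceding sentence exceeds the budget: skip its first 10 words and
    // carry only the last 20 into the next chunk.
    int skipped = SentenceBoundaryChunker.numWordsToSkipInPreviousSentence(30, 40); // 30 - 20 == 10
    // A 15-word preceding sentence fits within the budget, so nothing is skipped.
    int none = SentenceBoundaryChunker.numWordsToSkipInPreviousSentence(15, 40);    // 0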

+ 2 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java

@@ -59,13 +59,12 @@ public class SentenceBoundaryChunkingSettingsTests extends AbstractWireSerializi
 
     @Override
     protected SentenceBoundaryChunkingSettings createTestInstance() {
-        return new SentenceBoundaryChunkingSettings(randomNonNegativeInt());
+        return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 0 : 1);
     }
 
     @Override
     protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException {
         var chunkSize = randomValueOtherThan(instance.maxChunkSize, ESTestCase::randomNonNegativeInt);
-
-        return new SentenceBoundaryChunkingSettings(chunkSize);
+        return new SentenceBoundaryChunkingSettings(chunkSize, instance.sentenceOverlap);
     }
 }

+ 1 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java

@@ -54,9 +54,6 @@ public class WordBoundaryChunkerTests extends ESTestCase {
             + " خليفہ المومنين يا خليفہ المسلمين يا صحابی يا رضي الله عنه چئي۔ (ب) آنحضور ﷺ جي گھروارين کان علاوه ڪنھن کي ام المومنين "
             + "چئي۔ (ج) آنحضور ﷺ جي خاندان جي اھل بيت کان علاوہڍه ڪنھن کي اھل بيت چئي۔ (د) پنھنجي عبادت گاھ کي مسجد چئي۔" };
 
-    private static final int DEFAULT_MAX_CHUNK_SIZE = 250;
-    private static final int DEFAULT_OVERLAP = 100;
-
     public static int NUM_WORDS_IN_TEST_TEXT;
     static {
         var wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
@@ -139,7 +136,7 @@ public class WordBoundaryChunkerTests extends ESTestCase {
     }
 
     public void testInvalidChunkingSettingsProvided() {
-        ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt());
+        ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), 0);
         assertThrows(IllegalArgumentException.class, () -> { new WordBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); });
     }