Text similarity reranker chunks and scores snippets (#133576)

Kathleen DeRusso 1 month ago
parent
commit
436ec11ce2
17 changed files with 640 additions and 294 deletions
  1. 5 0
      docs/changelog/133576.yaml
  2. 2 0
      server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java
  3. 1 0
      x-pack/plugin/core/src/main/java/module-info.java
  4. 98 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorer.java
  5. 95 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorerTests.java
  6. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java
  7. 21 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java
  8. 21 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java
  9. 100 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/ChunkScorerConfig.java
  10. 0 88
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java
  11. 15 73
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
  12. 6 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
  13. 37 30
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
  14. 35 54
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java
  15. 8 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java
  16. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java
  17. 175 33
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

+ 5 - 0
docs/changelog/133576.yaml

@@ -0,0 +1,5 @@
+pr: 133576
+summary: Text similarity reranker chunks and scores snippets
+area: Relevance
+type: enhancement
+issues: []

+ 2 - 0
server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java

@@ -24,4 +24,6 @@ public interface ChunkingSettings extends ToXContentObject, VersionedNamedWritea
      * @return The max chunk size specified, or null if not specified
      */
     Integer maxChunkSize();
+
+    default void validate() {}
 }

+ 1 - 0
x-pack/plugin/core/src/main/java/module-info.java

@@ -234,6 +234,7 @@ module org.elasticsearch.xcore {
     exports org.elasticsearch.xpack.core.watcher.watch;
     exports org.elasticsearch.xpack.core.watcher;
     exports org.elasticsearch.xpack.core.security.authc.apikey;
+    exports org.elasticsearch.xpack.core.common.chunks;
 
     provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
         with

+ 98 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorer.java

@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.common.chunks;
+
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.ByteBuffersDirectory;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.QueryBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Utility class for scoring pre-determined chunks using an in-memory Lucene index.
+ */
+public class MemoryIndexChunkScorer {
+
+    private static final String CONTENT_FIELD = "content";
+
+    private final StandardAnalyzer analyzer;
+
+    public MemoryIndexChunkScorer() {
+        // TODO: Allow analyzer to be customizable and/or read from the field mapping
+        this.analyzer = new StandardAnalyzer();
+    }
+
+    /**
+     * Creates an in-memory index of the provided chunks, scores them against the inference text, and returns an ordered, scored list.
+     *
+     * @param chunks the list of text chunks to score
+     * @param inferenceText the query text to compare against
+     * @param maxResults maximum number of results to return
+     * @return list of scored chunks ordered by relevance
+     * @throws IOException on failure scoring chunks
+     */
+    public List<ScoredChunk> scoreChunks(List<String> chunks, String inferenceText, int maxResults) throws IOException {
+        if (chunks == null || chunks.isEmpty() || inferenceText == null || inferenceText.trim().isEmpty()) {
+            return new ArrayList<>();
+        }
+
+        try (Directory directory = new ByteBuffersDirectory()) {
+            IndexWriterConfig config = new IndexWriterConfig(analyzer);
+            try (IndexWriter writer = new IndexWriter(directory, config)) {
+                for (String chunk : chunks) {
+                    Document doc = new Document();
+                    doc.add(new TextField(CONTENT_FIELD, chunk, Field.Store.YES));
+                    writer.addDocument(doc);
+                }
+                writer.commit();
+            }
+
+            try (DirectoryReader reader = DirectoryReader.open(directory)) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+
+                QueryBuilder qb = new QueryBuilder(analyzer);
+                Query query = qb.createBooleanQuery(CONTENT_FIELD, inferenceText, BooleanClause.Occur.SHOULD);
+                int numResults = Math.min(maxResults, chunks.size());
+                TopDocs topDocs = searcher.search(query, numResults);
+
+                List<ScoredChunk> scoredChunks = new ArrayList<>();
+                for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+                    Document doc = reader.storedFields().document(scoreDoc.doc);
+                    String content = doc.get(CONTENT_FIELD);
+                    scoredChunks.add(new ScoredChunk(content, scoreDoc.score));
+                }
+
+                // It's possible that no chunks were scorable (for example, a semantic match that does not have a lexical match).
+                // In this case, we'll return the first N chunks with a score of 0.
+                // TODO: consider parameterizing this
+                return scoredChunks.isEmpty() == false
+                    ? scoredChunks
+                    : chunks.subList(0, Math.min(maxResults, chunks.size())).stream().map(c -> new ScoredChunk(c, 0.0f)).toList();
+            }
+        }
+    }
+
+    /**
+     * Represents a chunk with its relevance score.
+     */
+    public record ScoredChunk(String content, float score) {}
+}

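For orientation, a minimal sketch of driving the new scorer directly, outside the rerank phase; the chunk list and query text mirror the tests below and are purely illustrative:

import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer;

import java.io.IOException;
import java.util.List;

public class MemoryIndexChunkScorerExample {
    public static void main(String[] args) throws IOException {
        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
        List<String> chunks = List.of(
            "Dogs love to play with toys and go for walks",
            "The weather today is very sunny and warm"
        );
        // Build a transient in-memory index over the chunks and score them against the query text,
        // keeping only the single best-matching chunk.
        List<MemoryIndexChunkScorer.ScoredChunk> best = scorer.scoreChunks(chunks, "dogs play walk", 1);
        best.forEach(chunk -> System.out.println(chunk.score() + " -> " + chunk.content()));
    }
}
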
+ 95 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/chunks/MemoryIndexChunkScorerTests.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.common.chunks;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+
+public class MemoryIndexChunkScorerTests extends ESTestCase {
+
+    private static final List<String> CHUNKS = Arrays.asList(
+        "Cats like to sleep all day and play with mice",
+        "Dogs are loyal companions and great pets",
+        "The weather today is very sunny and warm",
+        "Dogs love to play with toys and go for walks",
+        "Elasticsearch is a great search engine"
+    );
+
+    public void testScoreChunks() throws IOException {
+        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
+
+        String inferenceText = "dogs play walk";
+        int maxResults = 3;
+
+        List<MemoryIndexChunkScorer.ScoredChunk> scoredChunks = scorer.scoreChunks(CHUNKS, inferenceText, maxResults);
+
+        assertEquals(maxResults, scoredChunks.size());
+
+        // The chunks about dogs should score highest, followed by the chunk about cats
+        MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst();
+        assertTrue(chunk.content().equalsIgnoreCase("Dogs love to play with toys and go for walks"));
+        assertThat(chunk.score(), greaterThan(0f));
+
+        chunk = scoredChunks.get(1);
+        assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets"));
+        assertThat(chunk.score(), greaterThan(0f));
+
+        chunk = scoredChunks.get(2);
+        assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice"));
+        assertThat(chunk.score(), greaterThan(0f));
+
+        // Scores should be in descending order
+        for (int i = 1; i < scoredChunks.size(); i++) {
+            assertTrue(scoredChunks.get(i - 1).score() >= scoredChunks.get(i).score());
+        }
+    }
+
+    public void testEmptyChunks() throws IOException {
+
+        int maxResults = 3;
+
+        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
+
+        // Zero results
+        List<MemoryIndexChunkScorer.ScoredChunk> scoredChunks = scorer.scoreChunks(CHUNKS, "puggles", maxResults);
+        assertEquals(maxResults, scoredChunks.size());
+
+        // There were no results so we return the first N chunks in order
+        MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst();
+        assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice"));
+        assertThat(chunk.score(), equalTo(0f));
+
+        chunk = scoredChunks.get(1);
+        assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets"));
+        assertThat(chunk.score(), equalTo(0f));
+
+        chunk = scoredChunks.get(2);
+        assertTrue(chunk.content().equalsIgnoreCase("The weather today is very sunny and warm"));
+        assertThat(chunk.score(), equalTo(0f));
+
+        // Null and Empty chunk input
+        scoredChunks = scorer.scoreChunks(List.of(), "puggles", maxResults);
+        assertTrue(scoredChunks.isEmpty());
+
+        scoredChunks = scorer.scoreChunks(CHUNKS, "", maxResults);
+        assertTrue(scoredChunks.isEmpty());
+
+        scoredChunks = scorer.scoreChunks(null, "puggles", maxResults);
+        assertTrue(scoredChunks.isEmpty());
+
+        scoredChunks = scorer.scoreChunks(CHUNKS, null, maxResults);
+        assertTrue(scoredChunks.isEmpty());
+    }
+
+}

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java

@@ -52,6 +52,25 @@ public class RecursiveChunkingSettings implements ChunkingSettings {
         separators = in.readCollectionAsList(StreamInput::readString);
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (separators != null && separators.isEmpty()) {
+            validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+    }
+
     public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
 

+ 21 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

@@ -59,6 +59,27 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
         return maxChunkSize;
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (sentenceOverlap > 1 || sentenceOverlap < 0) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
+            );
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+    }
+
     @Override
     public Map<String, Object> asMap() {
         return Map.of(

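A brief sketch of the new validate() hook in action, assuming the (maxChunkSize, sentenceOverlap) constructor order used elsewhere in this change and org.elasticsearch.common.ValidationException; values are illustrative:

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;

public class ChunkingSettingsValidateExample {
    public static void main(String[] args) {
        // sentence_overlap may only be 0 or 1, so validate() collects an error and throws.
        SentenceBoundaryChunkingSettings invalid = new SentenceBoundaryChunkingSettings(300, 2);
        try {
            invalid.validate();
        } catch (ValidationException e) {
            System.out.println(e.getMessage());
        }

        // Settings within the documented limits validate without throwing.
        new SentenceBoundaryChunkingSettings(300, 0).validate();
    }
}
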
+ 21 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java

@@ -48,6 +48,27 @@ public class WordBoundaryChunkingSettings implements ChunkingSettings {
         overlap = in.readInt();
     }
 
+    @Override
+    public void validate() {
+        ValidationException validationException = new ValidationException();
+
+        if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
+            );
+        }
+
+        if (overlap > maxChunkSize / 2) {
+            validationException.addValidationError(
+                ChunkingSettingsOptions.OVERLAP + "[" + overlap + "] must be less than or equal to half of max chunk size"
+            );
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+    }
+
     @Override
     public Map<String, Object> asMap() {
         return Map.of(

+ 100 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/ChunkScorerConfig.java

@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rank.textsimilarity;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+public class ChunkScorerConfig implements Writeable {
+
+    public final Integer size;
+    private final String inferenceText;
+    private final ChunkingSettings chunkingSettings;
+
+    public static final int DEFAULT_CHUNK_SIZE = 300;
+    public static final int DEFAULT_SIZE = 1;
+
+    public static ChunkingSettings createChunkingSettings(Integer chunkSize) {
+        int chunkSizeOrDefault = chunkSize != null ? chunkSize : DEFAULT_CHUNK_SIZE;
+        ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(chunkSizeOrDefault, 0);
+        chunkingSettings.validate();
+        return chunkingSettings;
+    }
+
+    public static ChunkingSettings chunkingSettingsFromMap(Map<String, Object> map) {
+
+        if (map == null || map.isEmpty()) {
+            return createChunkingSettings(DEFAULT_CHUNK_SIZE);
+        }
+
+        if (map.size() == 1 && map.containsKey("max_chunk_size")) {
+            return createChunkingSettings((Integer) map.get("max_chunk_size"));
+        }
+
+        return ChunkingSettingsBuilder.fromMap(map);
+    }
+
+    public ChunkScorerConfig(StreamInput in) throws IOException {
+        this.size = in.readOptionalVInt();
+        this.inferenceText = in.readString();
+        Map<String, Object> chunkingSettingsMap = in.readGenericMap();
+        this.chunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
+    }
+
+    public ChunkScorerConfig(Integer size, ChunkingSettings chunkingSettings) {
+        this(size, null, chunkingSettings);
+    }
+
+    public ChunkScorerConfig(Integer size, String inferenceText, ChunkingSettings chunkingSettings) {
+        this.size = size;
+        this.inferenceText = inferenceText;
+        this.chunkingSettings = chunkingSettings;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalVInt(size);
+        out.writeString(inferenceText);
+        out.writeGenericMap(chunkingSettings.asMap());
+    }
+
+    public Integer size() {
+        return size;
+    }
+
+    public String inferenceText() {
+        return inferenceText;
+    }
+
+    public ChunkingSettings chunkingSettings() {
+        return chunkingSettings;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ChunkScorerConfig that = (ChunkScorerConfig) o;
+        return Objects.equals(size, that.size)
+            && Objects.equals(inferenceText, that.inferenceText)
+            && Objects.equals(chunkingSettings, that.chunkingSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(size, inferenceText, chunkingSettings);
+    }
+}

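A short sketch of how the config above is assembled with its defaults; the inference text is illustrative:

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.inference.rank.textsimilarity.ChunkScorerConfig;

public class ChunkScorerConfigExample {
    public static void main(String[] args) {
        // A null chunk size falls back to DEFAULT_CHUNK_SIZE (300) with sentence-boundary chunking.
        ChunkingSettings settings = ChunkScorerConfig.createChunkingSettings(null);

        // The coordinator carries the inference text so shards can score chunks against it.
        ChunkScorerConfig config = new ChunkScorerConfig(
            ChunkScorerConfig.DEFAULT_SIZE,
            "How often does the moon hide the sun?",
            settings
        );
        System.out.println(config.size() + " chunk(s) per document will be sent to the reranker");
    }
}
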
+ 0 - 88
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java

@@ -1,88 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.rank.textsimilarity;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.index.query.QueryBuilder;
-
-import java.io.IOException;
-import java.util.Objects;
-
-public class SnippetConfig implements Writeable {
-
-    public final Integer numSnippets;
-    private final String inferenceText;
-    private final Integer tokenSizeLimit;
-    public final QueryBuilder snippetQueryBuilder;
-
-    public static final int DEFAULT_NUM_SNIPPETS = 1;
-
-    public SnippetConfig(StreamInput in) throws IOException {
-        this.numSnippets = in.readOptionalVInt();
-        this.inferenceText = in.readString();
-        this.tokenSizeLimit = in.readOptionalVInt();
-        this.snippetQueryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
-    }
-
-    public SnippetConfig(Integer numSnippets) {
-        this(numSnippets, null, null);
-    }
-
-    public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit) {
-        this(numSnippets, inferenceText, tokenSizeLimit, null);
-    }
-
-    public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit, QueryBuilder snippetQueryBuilder) {
-        this.numSnippets = numSnippets;
-        this.inferenceText = inferenceText;
-        this.tokenSizeLimit = tokenSizeLimit;
-        this.snippetQueryBuilder = snippetQueryBuilder;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeOptionalVInt(numSnippets);
-        out.writeString(inferenceText);
-        out.writeOptionalVInt(tokenSizeLimit);
-        out.writeOptionalNamedWriteable(snippetQueryBuilder);
-    }
-
-    public Integer numSnippets() {
-        return numSnippets;
-    }
-
-    public String inferenceText() {
-        return inferenceText;
-    }
-
-    public Integer tokenSizeLimit() {
-        return tokenSizeLimit;
-    }
-
-    public QueryBuilder snippetQueryBuilder() {
-        return snippetQueryBuilder;
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        SnippetConfig that = (SnippetConfig) o;
-        return Objects.equals(numSnippets, that.numSnippets)
-            && Objects.equals(inferenceText, that.inferenceText)
-            && Objects.equals(tokenSizeLimit, that.tokenSizeLimit)
-            && Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(numSnippets, inferenceText, tokenSizeLimit, snippetQueryBuilder);
-    }
-}

+ 15 - 73
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

@@ -15,9 +15,6 @@ import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.index.query.MatchQueryBuilder;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.search.rank.RankBuilder;
@@ -33,12 +30,12 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.CHUNK_RESCORER_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FAILURES_ALLOWED_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.FIELD_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD;
-import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD;
 
 /**
  * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call.
@@ -47,11 +44,6 @@ public class TextSimilarityRankBuilder extends RankBuilder {
 
     public static final String NAME = "text_similarity_reranker";
 
-    /**
-     * The default token size limit of the Elastic reranker is 512.
-     */
-    private static final int DEFAULT_TOKEN_SIZE_LIMIT = 512;
-
     public static final LicensedFeature.Momentary TEXT_SIMILARITY_RERANKER_FEATURE = LicensedFeature.momentary(
         null,
         "text-similarity-reranker",
@@ -65,7 +57,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
     private final String field;
     private final Float minScore;
     private final boolean failuresAllowed;
-    private final SnippetConfig snippetConfig;
+    private final ChunkScorerConfig chunkScorerConfig;
 
     public TextSimilarityRankBuilder(
         String field,
@@ -74,7 +66,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         int rankWindowSize,
         Float minScore,
         boolean failuresAllowed,
-        SnippetConfig snippetConfig
+        ChunkScorerConfig chunkScorerConfig
     ) {
         super(rankWindowSize);
         this.inferenceId = inferenceId;
@@ -82,7 +74,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         this.field = field;
         this.minScore = minScore;
         this.failuresAllowed = failuresAllowed;
-        this.snippetConfig = snippetConfig;
+        this.chunkScorerConfig = chunkScorerConfig;
     }
 
     public TextSimilarityRankBuilder(StreamInput in) throws IOException {
@@ -99,9 +91,9 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             this.failuresAllowed = false;
         }
         if (in.getTransportVersion().supports(RERANK_SNIPPETS)) {
-            this.snippetConfig = in.readOptionalWriteable(SnippetConfig::new);
+            this.chunkScorerConfig = in.readOptionalWriteable(ChunkScorerConfig::new);
         } else {
-            this.snippetConfig = null;
+            this.chunkScorerConfig = null;
         }
     }
 
@@ -127,7 +119,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             out.writeBoolean(failuresAllowed);
         }
         if (out.getTransportVersion().supports(RERANK_SNIPPETS)) {
-            out.writeOptionalWriteable(snippetConfig);
+            out.writeOptionalWriteable(chunkScorerConfig);
         }
     }
 
@@ -144,53 +136,9 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true);
         }
-        if (snippetConfig != null) {
-            builder.field(SNIPPETS_FIELD.getPreferredName(), snippetConfig);
-        }
-    }
-
-    @Override
-    public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
-        TextSimilarityRankBuilder rewritten = this;
-        if (snippetConfig != null) {
-            QueryBuilder snippetQueryBuilder = snippetConfig.snippetQueryBuilder();
-            if (snippetQueryBuilder == null) {
-                rewritten = new TextSimilarityRankBuilder(
-                    field,
-                    inferenceId,
-                    inferenceText,
-                    rankWindowSize(),
-                    minScore,
-                    failuresAllowed,
-                    new SnippetConfig(
-                        snippetConfig.numSnippets(),
-                        snippetConfig.inferenceText(),
-                        snippetConfig.tokenSizeLimit(),
-                        new MatchQueryBuilder(field, inferenceText)
-                    )
-                );
-            } else {
-                QueryBuilder rewrittenSnippetQueryBuilder = snippetQueryBuilder.rewrite(queryRewriteContext);
-                if (snippetQueryBuilder != rewrittenSnippetQueryBuilder) {
-                    rewritten = new TextSimilarityRankBuilder(
-                        field,
-                        inferenceId,
-                        inferenceText,
-                        rankWindowSize(),
-                        minScore,
-                        failuresAllowed,
-                        new SnippetConfig(
-                            snippetConfig.numSnippets(),
-                            snippetConfig.inferenceText(),
-                            snippetConfig.tokenSizeLimit(),
-                            rewrittenSnippetQueryBuilder
-                        )
-                    );
-                }
-            }
+        if (chunkScorerConfig != null) {
+            builder.field(CHUNK_RESCORER_FIELD.getPreferredName(), chunkScorerConfig);
         }
-
-        return rewritten;
     }
 
     @Override
@@ -237,7 +185,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
 
     @Override
     public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
-        return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, snippetConfig);
+        return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, chunkScorerConfig);
     }
 
     @Override
@@ -251,18 +199,12 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             inferenceText,
             minScore,
             failuresAllowed,
-            snippetConfig != null ? new SnippetConfig(snippetConfig.numSnippets, inferenceText, tokenSizeLimit(inferenceId)) : null
+            chunkScorerConfig != null
+                ? new ChunkScorerConfig(chunkScorerConfig.size, inferenceText, chunkScorerConfig.chunkingSettings())
+                : null
         );
     }
 
-    /**
-     * @return The token size limit to apply to this rerank context.
-     * TODO: This should be pulled from the inference endpoint when available, not hardcoded.
-     */
-    public static Integer tokenSizeLimit(String inferenceId) {
-        return DEFAULT_TOKEN_SIZE_LIMIT;
-    }
-
     public String field() {
         return field;
     }
@@ -291,12 +233,12 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             && Objects.equals(field, that.field)
             && Objects.equals(minScore, that.minScore)
             && failuresAllowed == that.failuresAllowed
-            && Objects.equals(snippetConfig, that.snippetConfig);
+            && Objects.equals(chunkScorerConfig, that.chunkScorerConfig);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, snippetConfig);
+        return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, chunkScorerConfig);
     }
 
     @Override

+ 6 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

@@ -40,7 +40,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
     protected final String inferenceId;
     protected final String inferenceText;
     protected final Float minScore;
-    protected final SnippetConfig snippetConfig;
+    protected final ChunkScorerConfig chunkScorerConfig;
 
     public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
         int size,
@@ -51,14 +51,14 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         String inferenceText,
         Float minScore,
         boolean failuresAllowed,
-        @Nullable SnippetConfig snippetConfig
+        @Nullable ChunkScorerConfig chunkScorerConfig
     ) {
         super(size, from, rankWindowSize, failuresAllowed);
         this.client = client;
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.minScore = minScore;
-        this.snippetConfig = snippetConfig;
+        this.chunkScorerConfig = chunkScorerConfig;
     }
 
     @Override
@@ -80,8 +80,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
                 l.onResponse(originalScores);
             } else {
                 final float[] scores;
-                if (this.snippetConfig != null) {
-                    scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+                if (this.chunkScorerConfig != null) {
+                    scores = extractScoresFromRankedChunks(rankedDocs, featureDocs);
                 } else {
                     scores = extractScoresFromRankedDocs(rankedDocs);
                 }
@@ -200,7 +200,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         return scores;
     }
 
-    float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
+    float[] extractScoresFromRankedChunks(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
         float[] scores = new float[featureDocs.length];
         boolean[] hasScore = new boolean[featureDocs.length];
 

+ 37 - 30
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.util.FeatureFlag;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -26,6 +27,7 @@ import org.elasticsearch.xcontent.XContentParser;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
@@ -50,8 +52,9 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
     public static final ParseField FIELD_FIELD = new ParseField("field");
     public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures");
-    public static final ParseField SNIPPETS_FIELD = new ParseField("snippets");
-    public static final ParseField NUM_SNIPPETS_FIELD = new ParseField("num_snippets");
+    public static final ParseField CHUNK_RESCORER_FIELD = new ParseField("chunk_rescorer");
+    public static final ParseField CHUNK_SIZE_FIELD = new ParseField("size");
+    public static final ParseField CHUNKING_SETTINGS_FIELD = new ParseField("chunking_settings");
 
     public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
         new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
@@ -61,7 +64,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             String field = (String) args[3];
             int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
             boolean failuresAllowed = args[5] != null && (Boolean) args[5];
-            SnippetConfig snippets = (SnippetConfig) args[6];
+            ChunkScorerConfig chunkScorerConfig = (ChunkScorerConfig) args[6];
 
             return new TextSimilarityRankRetrieverBuilder(
                 retrieverBuilder,
@@ -70,18 +73,18 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
                 field,
                 rankWindowSize,
                 failuresAllowed,
-                snippets
+                chunkScorerConfig
             );
         });
 
-    private static final ConstructingObjectParser<SnippetConfig, RetrieverParserContext> SNIPPETS_PARSER = new ConstructingObjectParser<>(
-        SNIPPETS_FIELD.getPreferredName(),
-        true,
-        args -> {
-            Integer numSnippets = (Integer) args[0];
-            return new SnippetConfig(numSnippets);
-        }
-    );
+    private static final ConstructingObjectParser<ChunkScorerConfig, RetrieverParserContext> CHUNK_SCORER_PARSER =
+        new ConstructingObjectParser<>(CHUNK_RESCORER_FIELD.getPreferredName(), true, args -> {
+            Integer size = (Integer) args[0];
+            @SuppressWarnings("unchecked")
+            Map<String, Object> chunkingSettingsMap = (Map<String, Object>) args[1];
+            ChunkingSettings chunkingSettings = ChunkScorerConfig.chunkingSettingsFromMap(chunkingSettingsMap);
+            return new ChunkScorerConfig(size, chunkingSettings);
+        });
 
     static {
         PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
@@ -94,9 +97,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         PARSER.declareString(constructorArg(), FIELD_FIELD);
         PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
         PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD);
-        PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD);
+        PARSER.declareObject(optionalConstructorArg(), CHUNK_SCORER_PARSER, CHUNK_RESCORER_FIELD);
         if (RERANK_SNIPPETS.isEnabled()) {
-            SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_SNIPPETS_FIELD);
+            CHUNK_SCORER_PARSER.declareInt(optionalConstructorArg(), CHUNK_SIZE_FIELD);
+            CHUNK_SCORER_PARSER.declareObjectOrNull(optionalConstructorArg(), (p, c) -> p.map(), null, CHUNKING_SETTINGS_FIELD);
         }
 
         RetrieverBuilder.declareBaseParserFields(PARSER);
@@ -117,7 +121,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     private final String inferenceText;
     private final String field;
     private final boolean failuresAllowed;
-    private final SnippetConfig snippets;
+    private final ChunkScorerConfig chunkScorerConfig;
 
     public TextSimilarityRankRetrieverBuilder(
         RetrieverBuilder retrieverBuilder,
@@ -126,14 +130,14 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         String field,
         int rankWindowSize,
         boolean failuresAllowed,
-        SnippetConfig snippets
+        ChunkScorerConfig chunkScorerConfig
     ) {
         super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize);
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
         this.failuresAllowed = failuresAllowed;
-        this.snippets = snippets;
+        this.chunkScorerConfig = chunkScorerConfig;
     }
 
     public TextSimilarityRankRetrieverBuilder(
@@ -146,14 +150,14 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         boolean failuresAllowed,
         String retrieverName,
         List<QueryBuilder> preFilterQueryBuilders,
-        SnippetConfig snippets
+        ChunkScorerConfig chunkScorerConfig
     ) {
         super(retrieverSource, rankWindowSize);
         if (retrieverSource.size() != 1) {
             throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever");
         }
-        if (snippets != null && snippets.numSnippets() != null && snippets.numSnippets() < 1) {
-            throw new IllegalArgumentException("num_snippets must be greater than 0, was: " + snippets.numSnippets());
+        if (chunkScorerConfig != null && chunkScorerConfig.size() != null && chunkScorerConfig.size() < 1) {
+            throw new IllegalArgumentException("size must be greater than 0, was: " + chunkScorerConfig.size());
         }
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
@@ -162,7 +166,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         this.failuresAllowed = failuresAllowed;
         this.retrieverName = retrieverName;
         this.preFilterQueryBuilders = preFilterQueryBuilders;
-        this.snippets = snippets;
+        this.chunkScorerConfig = chunkScorerConfig;
     }
 
     @Override
@@ -180,7 +184,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             failuresAllowed,
             retrieverName,
             newPreFilterQueryBuilders,
-            snippets
+            chunkScorerConfig
         );
     }
 
@@ -215,8 +219,8 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
                 rankWindowSize,
                 minScore,
                 failuresAllowed,
-                snippets != null
-                    ? new SnippetConfig(snippets.numSnippets, inferenceText, TextSimilarityRankBuilder.tokenSizeLimit(inferenceId))
+                chunkScorerConfig != null
+                    ? new ChunkScorerConfig(chunkScorerConfig.size, inferenceText, chunkScorerConfig.chunkingSettings())
                     : null
             )
         );
@@ -246,10 +250,13 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed);
         }
-        if (snippets != null) {
-            builder.startObject(SNIPPETS_FIELD.getPreferredName());
-            if (snippets.numSnippets() != null) {
-                builder.field(NUM_SNIPPETS_FIELD.getPreferredName(), snippets.numSnippets());
+        if (chunkScorerConfig != null) {
+            builder.startObject(CHUNK_RESCORER_FIELD.getPreferredName());
+            if (chunkScorerConfig.size() != null) {
+                builder.field(CHUNK_SIZE_FIELD.getPreferredName(), chunkScorerConfig.size());
+            }
+            if (chunkScorerConfig.chunkingSettings() != null) {
+                builder.field(CHUNKING_SETTINGS_FIELD.getPreferredName(), chunkScorerConfig.chunkingSettings().asMap());
             }
             builder.endObject();
         }
@@ -265,11 +272,11 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             && rankWindowSize == that.rankWindowSize
             && Objects.equals(minScore, that.minScore)
             && failuresAllowed == that.failuresAllowed
-            && Objects.equals(snippets, that.snippets);
+            && Objects.equals(chunkScorerConfig, that.chunkScorerConfig);
     }
 
     @Override
     public int doHashCode() {
-        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets);
+        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, chunkScorerConfig);
     }
 }

+ 35 - 54
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java

@@ -8,38 +8,34 @@
 package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
 import org.elasticsearch.common.document.DocumentField;
-import org.elasticsearch.common.logging.HeaderWarning;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
-import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
-import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
-import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
-import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.rank.RankShardResult;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
 import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
-import org.elasticsearch.xcontent.Text;
+import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer;
+import org.elasticsearch.xpack.inference.chunking.Chunker;
+import org.elasticsearch.xpack.inference.chunking.ChunkerBuilder;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.List;
-import java.util.Map;
 
-import static org.elasticsearch.xpack.inference.rank.textsimilarity.SnippetConfig.DEFAULT_NUM_SNIPPETS;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.ChunkScorerConfig.DEFAULT_SIZE;
 
 public class TextSimilarityRerankingRankFeaturePhaseRankShardContext extends RerankingRankFeaturePhaseRankShardContext {
 
-    private final SnippetConfig snippetRankInput;
+    private final ChunkScorerConfig chunkScorerConfig;
+    private final ChunkingSettings chunkingSettings;
+    private final Chunker chunker;
 
-    // Rough approximation of token size vs. characters in highlight fragments.
-    // TODO: highlighter should be able to set fragment size by token not length
-    private static final int TOKEN_SIZE_LIMIT_MULTIPLIER = 5;
-
-    public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable SnippetConfig snippetRankInput) {
+    public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable ChunkScorerConfig chunkScorerConfig) {
         super(field);
-        this.snippetRankInput = snippetRankInput;
+        this.chunkScorerConfig = chunkScorerConfig;
+        chunkingSettings = chunkScorerConfig != null ? chunkScorerConfig.chunkingSettings() : null;
+        chunker = chunkingSettings != null ? ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()) : null;
     }
 
     @Override
@@ -49,49 +45,34 @@ public class TextSimilarityRerankingRankFeaturePhaseRankShardContext extends Rer
             rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
             SearchHit hit = hits.getHits()[i];
             DocumentField docField = hit.field(field);
-            if (snippetRankInput == null && docField != null) {
-                rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
-            } else {
-                Map<String, HighlightField> highlightFields = hit.getHighlightFields();
-                if (highlightFields != null && highlightFields.containsKey(field) && highlightFields.get(field).fragments().length > 0) {
-                    List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
-                    rankFeatureDocs[i].featureData(snippets);
-                } else if (docField != null) {
-                    // If we did not get highlighting results, backfill with the doc field value
-                    // but pass in a warning because we are not reranking on snippets only
+            if (docField != null) {
+                if (chunkScorerConfig != null) {
+                    int size = chunkScorerConfig.size() != null ? chunkScorerConfig.size() : DEFAULT_SIZE;
+                    String fieldValue = docField.getValue().toString();
+                    List<Chunker.ChunkOffset> chunkOffsets = chunker.chunk(fieldValue, chunkingSettings);
+                    List<String> chunks = chunkOffsets.stream()
+                        .map(offset -> fieldValue.substring(offset.start(), offset.end()))
+                        .toList();
+
+                    List<String> bestChunks;
+                    try {
+                        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
+                        List<MemoryIndexChunkScorer.ScoredChunk> scoredChunks = scorer.scoreChunks(
+                            chunks,
+                            chunkScorerConfig.inferenceText(),
+                            size
+                        );
+                        bestChunks = scoredChunks.stream().map(MemoryIndexChunkScorer.ScoredChunk::content).limit(size).toList();
+                    } catch (IOException e) {
+                        throw new IllegalStateException("Could not generate chunks for input to reranker", e);
+                    }
+                    rankFeatureDocs[i].featureData(bestChunks);
+
+                } else {
                     rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
-                    HeaderWarning.addWarning(
-                        "Reranking on snippets requested, but no snippets were found for field [" + field + "]. Using field value instead."
-                    );
                 }
             }
         }
         return new RankFeatureShardResult(rankFeatureDocs);
     }
 
-    @Override
-    public void prepareForFetch(SearchContext context) {
-        if (snippetRankInput != null) {
-            try {
-                HighlightBuilder highlightBuilder = new HighlightBuilder();
-                highlightBuilder.highlightQuery(snippetRankInput.snippetQueryBuilder());
-                // Stripping pre/post tags as they're not useful for snippet creation
-                highlightBuilder.field(field).preTags("").postTags("");
-                // Return highest scoring fragments
-                highlightBuilder.order(HighlightBuilder.Order.SCORE);
-                int numSnippets = snippetRankInput.numSnippets() != null ? snippetRankInput.numSnippets() : DEFAULT_NUM_SNIPPETS;
-                highlightBuilder.numOfFragments(numSnippets);
-                // Rely on the model to determine the fragment size
-                int tokenSizeLimit = snippetRankInput.tokenSizeLimit();
-                int fragmentSize = tokenSizeLimit * TOKEN_SIZE_LIMIT_MULTIPLIER;
-                highlightBuilder.fragmentSize(fragmentSize);
-                highlightBuilder.noMatchSize(fragmentSize);
-                SearchHighlightContext searchHighlightContext = highlightBuilder.build(context.getSearchExecutionContext());
-                context.highlight(searchHighlightContext);
-            } catch (IOException e) {
-                throw new RuntimeException("Failed to generate snippet request", e);
-            }
-        }
-    }
-
 }

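To summarize the shard-side path above, a condensed sketch of the per-document chunk-then-score flow using the default settings; it stands in for the real wiring through RankFeatureDoc:

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer;
import org.elasticsearch.xpack.inference.chunking.Chunker;
import org.elasticsearch.xpack.inference.chunking.ChunkerBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.ChunkScorerConfig;

import java.io.IOException;
import java.util.List;

public class ChunkRescoreFlowExample {
    // Returns the top `size` chunks of `fieldValue`, scored lexically against `inferenceText`.
    static List<String> bestChunks(String fieldValue, String inferenceText, int size) throws IOException {
        ChunkingSettings settings = ChunkScorerConfig.createChunkingSettings(null);
        Chunker chunker = ChunkerBuilder.fromChunkingStrategy(settings.getChunkingStrategy());

        // Chunk the field value into offsets, then materialize the chunk strings.
        List<String> chunks = chunker.chunk(fieldValue, settings)
            .stream()
            .map(offset -> fieldValue.substring(offset.start(), offset.end()))
            .toList();

        // Score the chunks in an in-memory index and keep only the best-matching content.
        return new MemoryIndexChunkScorer().scoreChunks(chunks, inferenceText, size)
            .stream()
            .map(MemoryIndexChunkScorer.ScoredChunk::content)
            .toList();
    }
}
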
+ 8 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java

@@ -39,7 +39,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
         null
     );
 
-    TextSimilarityRankFeaturePhaseRankCoordinatorContext withSnippets = new TextSimilarityRankFeaturePhaseRankCoordinatorContext(
+    TextSimilarityRankFeaturePhaseRankCoordinatorContext withChunks = new TextSimilarityRankFeaturePhaseRankCoordinatorContext(
         10,
         0,
         100,
@@ -48,7 +48,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
         "some query",
         0.0f,
         false,
-        new SnippetConfig(2, "some query", 10)
+        new ChunkScorerConfig(2, "some query", null)
     );
 
     public void testComputeScores() {
@@ -87,7 +87,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
         assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, scores, 0.0f);
     }
 
-    public void testExtractScoresFromSingleSnippets() {
+    public void testExtractScoresFromSingleChunk() {
 
         List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
             new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"),
@@ -99,12 +99,12 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
             createRankFeatureDoc(1, 3.0f, 1, List.of("text 2")),
             createRankFeatureDoc(2, 2.0f, 0, List.of("text 3")) };
 
-        float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
-        // Returned cores are from the snippet, not the whole text
+        float[] scores = withChunks.extractScoresFromRankedChunks(rankedDocs, featureDocs);
+        // Returned scores are from the chunk, not the whole text
         assertArrayEquals(new float[] { 1.0f, 2.5f, 1.5f }, scores, 0.0f);
     }
 
-    public void testExtractScoresFromMultipleSnippets() {
+    public void testExtractScoresFromMultipleChunks() {
 
         List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
             new RankedDocsResults.RankedDoc(0, 1.0f, "this is text 1"),
@@ -119,8 +119,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
             createRankFeatureDoc(1, 3.0f, 1, List.of("yet more text", "this is text 2")),
             createRankFeatureDoc(2, 2.0f, 0, List.of("this is text 3", "oh look, more text")) };
 
-        float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
-        // Returned scores are from the best-ranking snippet, not the whole text
+        float[] scores = withChunks.extractScoresFromRankedChunks(rankedDocs, featureDocs);
+        // Returned scores are from the best-ranking chunk, not the whole text
         assertArrayEquals(new float[] { 2.5f, 3.0f, 2.0f }, scores, 0.0f);
     }
 

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java

@@ -177,9 +177,9 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
             Float minScore,
             boolean failuresAllowed,
             String throwingType,
-            SnippetConfig snippetConfig
+            ChunkScorerConfig chunkScorerConfig
         ) {
-            super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippetConfig);
+            super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, chunkScorerConfig);
             this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType);
         }
 

+ 175 - 33
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

@@ -515,14 +515,14 @@ setup:
 
 
 ---
-"Text similarity reranker specifying number of snippets must be > 0":
+"Text similarity reranker specifying number of rescore_chunks must be > 0":
 
   - requires:
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
-      catch: /num_snippets must be greater than 0/
+      catch: /size must be greater than 0/
       search:
         index: test-index
         body:
@@ -538,18 +538,18 @@ setup:
               inference_id: my-rerank-model
               inference_text: "How often does the moon hide the sun?"
               field: inference_text_field
-              snippets:
-                num_snippets: 0
+              chunk_rescorer:
+                size: 0
           size: 10
 
   - match: { status: 400 }
 
 ---
-"Reranking based on snippets":
+"Reranking based on rescore_chunks":
 
   - requires:
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
       search:
@@ -569,8 +569,8 @@ setup:
               inference_id: my-rerank-model
               inference_text: "How often does the moon hide the sun?"
               field: text
-              snippets:
-                num_snippets: 2
+              chunk_rescorer:
+                size: 2
           size: 10
 
   - match: { hits.total.value: 2 }
@@ -580,11 +580,11 @@ setup:
   - match: { hits.hits.1._id: "doc_2" }
 
 ---
-"Reranking based on snippets using defaults":
+"Reranking based on rescore_chunks using defaults":
 
   - requires:
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
       search:
@@ -603,7 +603,7 @@ setup:
               inference_id: my-rerank-model
               inference_text: "How often does the moon hide the sun?"
               field: text
-              snippets: { }
+              chunk_rescorer: { }
           size: 10
 
   - match: { hits.total.value: 2 }
@@ -613,11 +613,11 @@ setup:
   - match: { hits.hits.1._id: "doc_2" }
 
 ---
-"Reranking based on snippets on a semantic_text field":
+"Reranking based on rescore_chunks on a semantic_text field":
 
   - requires:
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
       search:
@@ -637,8 +637,8 @@ setup:
               inference_id: my-rerank-model
               inference_text: "how often does the moon hide the sun?"
               field: semantic_text_field
-              snippets:
-                num_snippets: 2
+              chunk_rescorer:
+                size: 2
           size: 10
 
   - match: { hits.total.value: 2 }
@@ -648,11 +648,11 @@ setup:
   - match: { hits.hits.1._id: "doc_2" }
 
 ---
-"Reranking based on snippets on a semantic_text field using defaults":
+"Reranking based on rescore_chunks on a semantic_text field using defaults":
 
   - requires:
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
       search:
@@ -672,7 +672,7 @@ setup:
               inference_id: my-rerank-model
               inference_text: "how often does the moon hide the sun?"
               field: semantic_text_field
-              snippets: { }
+              chunk_rescorer: { }
           size: 10
 
   - match: { hits.total.value: 2 }
@@ -682,38 +682,180 @@ setup:
   - match: { hits.hits.1._id: "doc_2" }
 
 ---
-"Reranking based on snippets when highlighter doesn't return results":
+"Reranking based on rescore_chunks on a semantic_text field specifying chunking settings":
 
   - requires:
-      test_runner_features: allowed_warnings
       cluster_features: "text_similarity_reranker_snippets"
-      reason: snippets introduced in 9.2.0
+      reason: rescore_chunks introduced in 9.2.0
 
   - do:
-      allowed_warnings:
-        - "Reranking on snippets requested, but no snippets were found for field [inference_text_field]. Using field value instead."
       search:
         index: test-index
         body:
           track_total_hits: true
-          fields: [ "text", "topic" ]
+          fields: [ "text", "semantic_text_field", "topic" ]
           retriever:
             text_similarity_reranker:
               retriever:
                 standard:
                   query:
-                    term:
-                      topic: "science"
+                    match:
+                      topic:
+                        query: "science"
               rank_window_size: 10
               inference_id: my-rerank-model
-              inference_text: "How often does the moon hide the sun?"
-              field: inference_text_field
-              snippets:
-                num_snippets: 2
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              chunk_rescorer:
+                chunking_settings:
+                  strategy: sentence
+                  max_chunk_size: 20
+                  sentence_overlap: 0
           size: 10
 
   - match: { hits.total.value: 2 }
   - length: { hits.hits: 2 }
 
-  - match: { hits.hits.0._id: "doc_2" }
-  - match: { hits.hits.1._id: "doc_1" }
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on rescore_chunks on a semantic_text field specifying chunking settings requires valid chunking settings":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: rescore_chunks introduced in 9.2.0
+
+  - do:
+      catch: /Invalid value/
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              chunk_rescorer:
+                chunk_size: 20
+                chunking_settings:
+                  strategy: sentence
+                  max_chunk_size: 10
+                  sentence_overlap: 20
+          size: 10
+
+---
+"Reranking based on rescore_chunks on a semantic_text field specifying chunk size":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: rescore_chunks introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              chunk_rescorer:
+                size: 20
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on chunk_rescorer specifying only max chunk size will default remaining chunking settings":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: rescore_chunks introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              chunk_rescorer:
+                chunking_settings:
+                  max_chunk_size: 20
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+
+---
+"Reranking based on chunk_rescorer will send in first chunk if no text matches found":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: rescore_chunks introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "iamanonsensefieldthatshouldreturnnoresults"
+              field: semantic_text_field
+              chunk_rescorer: { }
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }