浏览代码

Normalize negative scores for text_similarity_reranker retriever (#120930) (#121050)

Panagiotis Bailis 8 月之前
父节点
当前提交
751c1c52d3
共有 13 个文件被更改,包括 105 次插入25 次删除
  1. 6 0
      docs/changelog/120930.yaml
  2. 17 0
      docs/reference/search/retriever.asciidoc
  3. 10 1
      server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java
  4. 12 0
      server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java
  5. 13 3
      test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java
  6. 7 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java
  7. 5 1
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
  8. 19 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
  9. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
  10. 5 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java
  11. 9 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
  12. 1 0
      x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
  13. 0 6
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

+ 6 - 0
docs/changelog/120930.yaml

@@ -0,0 +1,6 @@
+pr: 120930
+summary: Normalize negative scores for `text_similarity_reranker` retriever
+area: Ranking
+type: bug
+issues:
+ - 120201

+ 17 - 0
docs/reference/search/retriever.asciidoc

@@ -523,6 +523,23 @@ You have the following options:
 ** Then set up an <<inference-example-eland,{es} service inference endpoint>> with the `rerank` task type.
 ** Refer to the <<text-similarity-reranker-retriever-example-eland,example>> on this page for a step-by-step guide.
 
+[IMPORTANT]
+====
+Scores from the re-ranking process are normalized using the following formula before being returned to the user,
+to avoid having negative scores.
+[source,text]
+----
+score = max(score, 0) + min(exp(score), 1)
+----
+Using the above, any initially negative scores are projected to (0, 1) and non-negative scores to [1, infinity).
+To revert back if needed, one can use:
+[source, text]
+----
+score = score - 1, if score >= 1
+score = ln(score), if score < 1
+----
+====
+
 ===== Parameters
 
 `retriever`::

+ 10 - 1
server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

@@ -58,6 +58,11 @@ public class RankDocsQuery extends Query {
             this.queryNames = queryNames;
             this.segmentStarts = segmentStarts;
             this.contextIdentity = contextIdentity;
+            for (RankDoc doc : docs) {
+                if (false == doc.score >= 0) {
+                    throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
+                }
+            }
         }
 
         @Override
@@ -161,7 +166,11 @@ public class RankDocsQuery extends Query {
 
                         @Override
                         public float score() {
-                            return docs[upTo].score;
+                            // We could still end up with a valid 0 score for a RankDoc
+                            // so here we want to differentiate between this and all the tailQuery matches
+                            // that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
+                            // RankDoc matches.
+                            return Math.max(docs[upTo].score, Float.MIN_VALUE);
                         }
 
                         @Override

+ 12 - 0
server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java

@@ -251,4 +251,16 @@ public class RankDocsQueryBuilderTests extends AbstractQueryTestCase<RankDocsQue
     public void testValidOutput() throws IOException {
         // no-op since RankDocsQueryBuilder is an internal only API
     }
+
+    public void shouldThrowForNegativeScores() throws IOException {
+        try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
+            iw.addDocument(new Document());
+            try (IndexReader reader = iw.getReader()) {
+                SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
+                RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
+                IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
+                assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
+            }
+        }
+    }
 }

+ 13 - 3
test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java

@@ -56,6 +56,10 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
 
     protected abstract Collection<Class<? extends Plugin>> pluginsNeeded();
 
+    protected boolean shouldCheckScores() {
+        return true;
+    }
+
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
         return pluginsNeeded();
@@ -95,9 +99,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
                 int rank = 1;
                 for (SearchHit searchHit : response.getHits().getHits()) {
                     assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
-                    assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
                     assertThat(searchHit, hasRank(rank));
                     assertNotNull(searchHit.getFields().get(searchField));
+                    if (shouldCheckScores()) {
+                        assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
+                    }
                     rank++;
                 }
             }
@@ -140,9 +146,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
                 int rank = 3;
                 for (SearchHit searchHit : response.getHits().getHits()) {
                     assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
-                    assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
                     assertThat(searchHit, hasRank(rank));
                     assertNotNull(searchHit.getFields().get(searchField));
+                    if (shouldCheckScores()) {
+                        assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
+                    }
                     rank++;
                 }
             }
@@ -222,9 +230,11 @@ public abstract class AbstractRerankerIT extends ESIntegTestCase {
                 int rank = 1;
                 for (SearchHit searchHit : response.getHits().getHits()) {
                     assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
-                    assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
                     assertThat(searchHit, hasRank(rank));
                     assertNotNull(searchHit.getFields().get(searchField));
+                    if (shouldCheckScores()) {
+                        assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
+                    }
                     rank++;
                 }
             }

+ 7 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

@@ -26,9 +26,16 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Random;
 
 public abstract class AbstractTestInferenceService implements InferenceService {
 
+    protected static final Random random = new Random(
+        System.getProperty("tests.seed") == null
+            ? System.currentTimeMillis()
+            : Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16)
+    );
+
     protected static int stringWeight(String input, int position) {
         int hashCode = input.hashCode();
         if (hashCode < 0) {

+ 5 - 1
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

@@ -42,6 +42,7 @@ import java.util.List;
 import java.util.Map;
 
 public class TestRerankingServiceExtension implements InferenceServiceExtension {
+
     @Override
     public List<Factory> getInferenceServiceFactories() {
         return List.of(TestInferenceService::new);
@@ -149,9 +150,12 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
         private RankedDocsResults makeResults(List<String> input) {
             List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
             int totalResults = input.size();
+            float minScore = random.nextFloat(-1f, 1f);
             float resultDiff = 0.2f;
             for (int i = 0; i < input.size(); i++) {
-                results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
+                results.add(
+                    new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i))
+                );
             }
             return new RankedDocsResults(results);
         }

+ 19 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

@@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
 
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 
@@ -130,10 +130,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
      */
     @Override
     protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
-        return Arrays.stream(originalDocs)
-            .filter(doc -> minScore == null || doc.score >= minScore)
-            .sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
-            .toArray(RankFeatureDoc[]::new);
+        List<RankFeatureDoc> docs = new ArrayList<>();
+        for (RankFeatureDoc doc : originalDocs) {
+            if (minScore == null || doc.score >= minScore) {
+                doc.score = normalizeScore(doc.score);
+                docs.add(doc);
+            }
+        }
+        docs.sort(RankFeatureDoc::compareTo);
+        return docs.toArray(new RankFeatureDoc[0]);
     }
 
     protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@@ -154,7 +159,15 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
             scores[rankedDoc.index()] = rankedDoc.relevanceScore();
         }
-
         return scores;
     }
+
+    private static float normalizeScore(float score) {
+        // As some models might produce negative scores, we want to ensure that all scores will be positive
+        // so we will make use of the following normalization formula:
+        // score = max(score, 0) + min(exp(score), 1)
+        // this will ensure that all positive scores lie in the [1, inf) range,
+        // while negative values (and 0) will be shifted to (0, 1]
+        return Math.max(score, 0) + Math.min((float) Math.exp(score), 1);
+    }
 }

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -160,6 +160,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
         for (int i = 0; i < scoreDocs.length; i++) {
             ScoreDoc scoreDoc = scoreDocs[i];
+            assert scoreDoc.score >= 0;
             if (explain) {
                 textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
                     scoreDoc.doc,

+ 5 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java

@@ -50,4 +50,9 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT {
     public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception {
         // no-op
     }
+
+    @Override
+    protected boolean shouldCheckScores() {
+        return false;
+    }
 }

+ 9 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java

@@ -131,11 +131,12 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                 // Verify order, rank and score of results
                 SearchHit[] hits = response.getHits().getHits();
                 assertEquals(5, hits.length);
-                assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
-                assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
-                assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
-                assertHitHasRankScoreAndText(hits[3], 4, 1.0f, "1");
-                assertHitHasRankScoreAndText(hits[4], 5, 0.0f, "0");
+                // we add 1 to all expected scores because the default normalization shifts non-negative scores up by 1
+                assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
+                assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
+                assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
+                assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
+                assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
             }
         );
     }
@@ -150,9 +151,9 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                 // Verify order, rank and score of results
                 SearchHit[] hits = response.getHits().getHits();
                 assertEquals(3, hits.length);
-                assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
-                assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
-                assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
+                assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
+                assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
+                assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
             }
         );
     }

+ 1 - 0
x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java

@@ -20,6 +20,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {
 
     @ClassRule
     public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+        .systemProperty("tests.seed", System.getProperty("tests.seed"))
         .setting("xpack.security.enabled", "false")
         .setting("xpack.security.http.ssl.enabled", "false")
         .setting("xpack.license.self_generated.type", "trial")

+ 0 - 6
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

@@ -91,10 +91,7 @@ setup:
   - length: { hits.hits: 2 }
 
   - match: { hits.hits.0._id: "doc_2" }
-  - close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } }
-
   - match: { hits.hits.1._id: "doc_1" }
-  - close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } }
 
 ---
 "Simple text similarity rank retriever and filtering":
@@ -125,8 +122,6 @@ setup:
   - length: { hits.hits: 1 }
 
   - match: { hits.hits.0._id: "doc_1" }
-  - close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } }
-
 
 ---
 "Text similarity reranking fails if the inference ID does not exist":
@@ -213,7 +208,6 @@ setup:
   - contains: { hits.hits: { _id: "doc_2" } }
   - contains: { hits.hits: { _id: "doc_1" } }
 
-  - close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } }
   - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
   - match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }