Ver código fonte

Add l2_norm normalization support to linear retriever (#128504)

* New l2 normalizer added

* L2 score normaliser is registered

* test case added to the yaml

* Documentation added

* Resolved checkstyle issues

* Update docs/changelog/128504.yaml

* Update docs/reference/elasticsearch/rest-apis/retrievers.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Score 0 test case added to check for corner cases

* Edited the markdown doc description

* Pruned the comment

* Renamed the variable

* Added comment to the class

* Unit tests added

* Spotless and checkstyle fixed

* Fixed build failure

* Fixed the forbidden test

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Mridula 4 meses atrás
pai
commit
81fba27b6b

+ 5 - 0
docs/changelog/128504.yaml

@@ -0,0 +1,5 @@
+pr: 128504
+summary: Add l2_norm normalization support to linear retriever
+area: Relevance
+type: enhancement
+issues: []

+ 2 - 1
docs/reference/elasticsearch/rest-apis/retrievers.md

@@ -276,7 +276,7 @@ Each entry specifies the following parameters:
 `normalizer`
 :   (Optional, String)
 
-    Specifies how we will normalize the retriever’s scores, before applying the specified `weight`. Available values are: `minmax`, and `none`. Defaults to `none`.
+    - Specifies how we will normalize the retriever’s scores, before applying the specified `weight`. Available values are: `minmax`, `l2_norm`, and `none`. Defaults to `none`.
 
     * `none`
     * `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula
@@ -285,6 +285,7 @@ Each entry specifies the following parameters:
         score = (score - min) / (max - min)
         ```
 
+    * `l2_norm` : An `L2ScoreNormalizer` that normalizes scores using the L2 norm of the score values.
 
 See also [this hybrid search example](docs-content://solutions/search/retrievers-examples.md#retrievers-examples-linear-retriever) using a linear retriever on how to independently configure and apply normalizers to retrievers.
 

+ 63 - 0
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/L2ScoreNormalizer.java

@@ -0,0 +1,63 @@
+
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.rank.linear;
+
+import org.apache.lucene.search.ScoreDoc;
+
+/**
+ * A score normalizer that applies L2 normalization to a set of scores.
+ * <p>
+ * This normalizer scales the scores so that the L2 norm of the score vector is 1,
+ * if possible. If all scores are zero or NaN, normalization is skipped and the original scores are returned.
+ * </p>
+ */
+public class L2ScoreNormalizer extends ScoreNormalizer {
+
+    public static final L2ScoreNormalizer INSTANCE = new L2ScoreNormalizer();
+
+    public static final String NAME = "l2_norm";
+
+    private static final float EPSILON = 1e-6f;
+
+    public L2ScoreNormalizer() {}
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public ScoreDoc[] normalizeScores(ScoreDoc[] docs) {
+        if (docs.length == 0) {
+            return docs;
+        }
+        double sumOfSquares = 0.0;
+        boolean atLeastOneValidScore = false;
+        for (ScoreDoc doc : docs) {
+            if (Float.isNaN(doc.score) == false) {
+                atLeastOneValidScore = true;
+                sumOfSquares += doc.score * doc.score;
+            }
+        }
+        if (atLeastOneValidScore == false) {
+            // No valid scores to normalize
+            return docs;
+        }
+        double norm = Math.sqrt(sumOfSquares);
+        if (norm < EPSILON) {
+            return docs;
+        }
+        ScoreDoc[] scoreDocs = new ScoreDoc[docs.length];
+        for (int i = 0; i < docs.length; i++) {
+            float score = (float) (docs[i].score / norm);
+            scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex);
+        }
+        return scoreDocs;
+    }
+}

+ 3 - 0
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java

@@ -17,6 +17,9 @@ public abstract class ScoreNormalizer {
     public static ScoreNormalizer valueOf(String normalizer) {
         if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
             return MinMaxScoreNormalizer.INSTANCE;
+        } else if (L2ScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
+            return L2ScoreNormalizer.INSTANCE;
+
         } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) {
             return IdentityScoreNormalizer.INSTANCE;
 

+ 54 - 0
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/L2ScoreNormalizerTests.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.rank.linear;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.elasticsearch.test.ESTestCase;
+
+public class L2ScoreNormalizerTests extends ESTestCase {
+
+    public void testNormalizeTypicalVector() {
+        ScoreDoc[] docs = { new ScoreDoc(1, 3.0f, 0), new ScoreDoc(2, 4.0f, 0) };
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertEquals(0.6f, normalized[0].score, 1e-5);
+        assertEquals(0.8f, normalized[1].score, 1e-5);
+    }
+
+    public void testAllZeros() {
+        ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, 0.0f, 0) };
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertEquals(0.0f, normalized[0].score, 0.0f);
+        assertEquals(0.0f, normalized[1].score, 0.0f);
+    }
+
+    public void testAllNaN() {
+        ScoreDoc[] docs = { new ScoreDoc(1, Float.NaN, 0), new ScoreDoc(2, Float.NaN, 0) };
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertTrue(Float.isNaN(normalized[0].score));
+        assertTrue(Float.isNaN(normalized[1].score));
+    }
+
+    public void testMixedZeroAndNaN() {
+        ScoreDoc[] docs = { new ScoreDoc(1, 0.0f, 0), new ScoreDoc(2, Float.NaN, 0) };
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertEquals(0.0f, normalized[0].score, 0.0f);
+        assertTrue(Float.isNaN(normalized[1].score));
+    }
+
+    public void testSingleElement() {
+        ScoreDoc[] docs = { new ScoreDoc(1, 42.0f, 0) };
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertEquals(1.0f, normalized[0].score, 1e-5);
+    }
+
+    public void testEmptyArray() {
+        ScoreDoc[] docs = {};
+        ScoreDoc[] normalized = L2ScoreNormalizer.INSTANCE.normalizeScores(docs);
+        assertEquals(0, normalized.length);
+    }
+}

+ 87 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml

@@ -265,6 +265,93 @@ setup:
   - match: { hits.hits.3._id: "3" }
   - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } }
 
+---
+"should normalize initial scores with l2_norm":
+  - do:
+      search:
+        index: test
+        body:
+          retriever:
+            linear:
+              retrievers: [
+                {
+                  retriever: {
+                    standard: {
+                      query: {
+                        bool: {
+                          should: [
+                            { constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 3.0 } },
+                            { constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 4.0 } }
+                          ]
+                        }
+                      }
+                    }
+                  },
+                  weight: 10.0,
+                  normalizer: "l2_norm"
+                },
+                {
+                  retriever: {
+                    standard: {
+                      query: {
+                        bool: {
+                          should: [
+                            { constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 6.0 } },
+                            { constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 8.0 } }
+                          ]
+                        }
+                      }
+                    }
+                  },
+                  weight: 2.0,
+                  normalizer: "l2_norm"
+                }
+              ]
+
+  - match: { hits.total.value: 4 }
+  - match: { hits.hits.0._id: "2" }
+  - match: { hits.hits.0._score: 8.0 }
+  - match: { hits.hits.1._id: "1" }
+  - match: { hits.hits.1._score: 6.0 }
+  - match: { hits.hits.2._id: "4" }
+  - close_to: { hits.hits.2._score: { value: 1.6, error: 0.001 } }
+  - match: { hits.hits.3._id: "3" }
+  - match: { hits.hits.3._score: 1.2 }
+
+---
+"should handle all zero scores in normalization":
+  - do:
+      search:
+        index: test
+        body:
+          retriever:
+            linear:
+              retrievers: [
+                {
+                  retriever: {
+                    standard: {
+                      query: {
+                        bool: {
+                          should: [
+                            { constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 0.0 } },
+                            { constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 0.0 } },
+                            { constant_score: { filter: { term: { keyword: { value: "three" } } }, boost: 0.0 } },
+                            { constant_score: { filter: { term: { keyword: { value: "four" } } }, boost: 0.0 } }
+                          ]
+                        }
+                      }
+                    }
+                  },
+                  weight: 1.0,
+                  normalizer: "l2_norm"
+                }
+              ]
+  - match: { hits.total.value: 4 }
+  - close_to: { hits.hits.0._score: { value: 0.0, error: 0.0001 } }
+  - close_to: { hits.hits.1._score: { value: 0.0, error: 0.0001 } }
+  - close_to: { hits.hits.2._score: { value: 0.0, error: 0.0001 } }
+  - close_to: { hits.hits.3._score: { value: 0.0, error: 0.0001 } }
+
 ---
 "should throw on unknown normalizer":
   - do: