Browse Source

Correct boost calculation in script_score query (#52478)

Before boost in script_score query was wrongly applied only to the subquery.
This commit makes sure that the boost is applied to the whole score
that comes out of script.

Closes #48465
Mayya Sharipova 5 years ago
parent
commit
556ee9a719

+ 4 - 1
docs/reference/query-dsl/script-score-query.asciidoc

@@ -48,9 +48,12 @@ scores be positive or `0`.
 --
 
 `min_score`::
-(Optional, float) Documents with a <<relevance-scores,relevance score>> lower
+(Optional, float) Documents with a score lower
 than this floating point number are excluded from the search results.
 
+`boost`::
+(Optional, float) Documents' scores produced by `script` are
+multiplied by `boost` to produce final documents' scores. Defaults to `1.0`.
 
 [[script-score-query-notes]]
 ==== Notes

+ 86 - 0
modules/lang-painless/src/test/resources/rest-api-spec/test/painless/110_script_score_boost.yml

@@ -0,0 +1,86 @@
+# Integration tests for ScriptScoreQuery using Painless
+setup:
+  - skip:
+      version: " - 7.9.99"
+      reason: "boost was corrected in script_score query from 8.0"
+  - do:
+      indices.create:
+        index: test_index
+        body:
+          settings:
+            index:
+              number_of_shards: 1
+              number_of_replicas: 0
+          mappings:
+            properties:
+              k:
+                type: keyword
+              i:
+                type: integer
+
+  - do:
+      bulk:
+        index: test_index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"k": "k", "i" : 1}'
+          - '{"index": {"_id": "2"}}'
+          - '{"k": "kk", "i" : 2}'
+          - '{"index": {"_id": "3"}}'
+          - '{"k": "kkk", "i" : 3}'
+---
+"Boost script_score":
+    - do:
+        search:
+          index: test_index
+          body:
+            query:
+              script_score:
+                query: {match_all: {}}
+                script:
+                  source: "doc['i'].value * _score"
+                boost: 10
+
+    - match: { hits.total.value: 3 }
+    - match: { hits.hits.0._score: 30 }
+    - match: { hits.hits.1._score: 20 }
+    - match: { hits.hits.2._score: 10 }
+
+---
+"Boost script_score and boost internal query":
+  - do:
+      search:
+        index: test_index
+        body:
+          query:
+            script_score:
+              query: {match_all: {boost: 5}}
+              script:
+                source: "doc['i'].value * _score"
+              boost: 10
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._score: 150 }
+  - match: { hits.hits.1._score: 100 }
+  - match: { hits.hits.2._score: 50 }
+
+---
+"Boost script_score with explain":
+  - do:
+      search:
+        index: test_index
+        body:
+          explain: true
+          query:
+            script_score:
+              query: {term: {"k": "kkk"}}
+              script:
+                source: "doc['i'].value"
+              boost: 10
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._score: 30 }
+  - match: { hits.hits.0._explanation.value: 30 }
+  - match: { hits.hits.0._explanation.details.0.description: "boost" }
+  - match: { hits.hits.0._explanation.details.0.value: 10}

+ 29 - 20
server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java

@@ -36,7 +36,6 @@ import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.util.Bits;
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.ScoreScript.ExplanationHolder;
@@ -85,7 +84,7 @@ public class ScriptScoreQuery extends Query {
         }
         boolean needsScore = scriptBuilder.needs_score();
         ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
-        Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, boost);
+        Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);
 
         return new Weight(this){
             @Override
@@ -95,7 +94,7 @@ public class ScriptScoreQuery extends Query {
                     if (subQueryBulkScorer == null) {
                         return null;
                     }
-                    return new ScriptScoreBulkScorer(subQueryBulkScorer, subQueryScoreMode, makeScoreScript(context));
+                    return new ScriptScoreBulkScorer(subQueryBulkScorer, subQueryScoreMode, makeScoreScript(context), boost);
                 } else {
                     return super.bulkScorer(context);
                 }
@@ -112,7 +111,7 @@ public class ScriptScoreQuery extends Query {
                 if (subQueryScorer == null) {
                     return null;
                 }
-                Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, null);
+                Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, boost, null);
                 if (minScore != null) {
                     scriptScorer = new MinScoreScorer(this, scriptScorer, minScore);
                 }
@@ -127,11 +126,11 @@ public class ScriptScoreQuery extends Query {
                 }
                 ExplanationHolder explanationHolder = new ExplanationHolder();
                 Scorer scorer = new ScriptScorer(this, makeScoreScript(context),
-                    subQueryWeight.scorer(context), subQueryScoreMode, explanationHolder);
+                    subQueryWeight.scorer(context), subQueryScoreMode, 1f, explanationHolder);
                 int newDoc = scorer.iterator().advance(doc);
                 assert doc == newDoc; // subquery should have already matched above
-                float score = scorer.score();
-                
+                float score = scorer.score(); // score without boost
+
                 Explanation explanation = explanationHolder.get(score, needsScore ? subQueryExplanation : null);
                 if (explanation == null) {
                     // no explanation provided by user; give a simple one
@@ -143,7 +142,10 @@ public class ScriptScoreQuery extends Query {
                         explanation = Explanation.match(score, desc);
                     }
                 }
-                
+                if (boost != 1f) {
+                    explanation = Explanation.match(boost * explanation.getValue().floatValue(), "Boosted score, product of:",
+                        Explanation.match(boost, "boost"), explanation);
+                }
                 if (minScore != null && minScore > explanation.getValue().floatValue()) {
                     explanation = Explanation.noMatch("Score value is too low, expected at least " + minScore +
                         " but got " + explanation.getValue(), explanation);
@@ -203,16 +205,18 @@ public class ScriptScoreQuery extends Query {
     private static class ScriptScorer extends Scorer {
         private final ScoreScript scoreScript;
         private final Scorer subQueryScorer;
+        private final float boost;
         private final ExplanationHolder explanation;
 
         ScriptScorer(Weight weight, ScoreScript scoreScript, Scorer subQueryScorer,
-                ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
+                ScoreMode subQueryScoreMode, float boost, ExplanationHolder explanation) {
             super(weight);
             this.scoreScript = scoreScript;
             if (subQueryScoreMode == ScoreMode.COMPLETE) {
                 scoreScript.setScorer(subQueryScorer);
             }
             this.subQueryScorer = subQueryScorer;
+            this.boost = boost;
             this.explanation = explanation;
         }
 
@@ -221,12 +225,13 @@ public class ScriptScoreQuery extends Query {
             int docId = docID();
             scoreScript.setDocument(docId);
             float score = (float) scoreScript.execute(explanation);
-            if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
-                throw new ElasticsearchException(
-                    "script_score query returned an invalid score [" + score + "] for doc [" + docId + "].");
+            if (score < 0f || Float.isNaN(score)) {
+                throw new IllegalArgumentException("script_score script returned an invalid score [" + score + "] " +
+                    "for doc [" + docId + "]. Must be a non-negative score!");
             }
-            return score;
+            return score * boost;
         }
+
         @Override
         public int docID() {
             return subQueryScorer.docID();
@@ -247,15 +252,17 @@ public class ScriptScoreQuery extends Query {
     private static class ScriptScorable extends Scorable {
         private final ScoreScript scoreScript;
         private final Scorable subQueryScorer;
+        private final float boost;
         private final ExplanationHolder explanation;
 
         ScriptScorable(ScoreScript scoreScript, Scorable subQueryScorer,
-                ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
+                ScoreMode subQueryScoreMode, float boost, ExplanationHolder explanation) {
             this.scoreScript = scoreScript;
             if (subQueryScoreMode == ScoreMode.COMPLETE) {
                 scoreScript.setScorer(subQueryScorer);
             }
             this.subQueryScorer = subQueryScorer;
+            this.boost = boost;
             this.explanation = explanation;
         }
 
@@ -264,11 +271,11 @@ public class ScriptScoreQuery extends Query {
             int docId = docID();
             scoreScript.setDocument(docId);
             float score = (float) scoreScript.execute(explanation);
-            if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
-                throw new ElasticsearchException(
-                    "script_score query returned an invalid score [" + score + "] for doc [" + docId + "].");
+            if (score < 0f || Float.isNaN(score)) {
+                throw new IllegalArgumentException("script_score script returned an invalid score [" + score + "] " +
+                    "for doc [" + docId + "]. Must be a non-negative score!");
             }
-            return score;
+            return score * boost;
         }
         @Override
         public int docID() {
@@ -284,11 +291,13 @@ public class ScriptScoreQuery extends Query {
         private final BulkScorer subQueryBulkScorer;
         private final ScoreMode subQueryScoreMode;
         private final ScoreScript scoreScript;
+        private final float boost;
 
-        ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript) {
+        ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript, float boost) {
             this.subQueryBulkScorer = subQueryBulkScorer;
             this.subQueryScoreMode = subQueryScoreMode;
             this.scoreScript = scoreScript;
+            this.boost = boost;
         }
 
         @Override
@@ -300,7 +309,7 @@ public class ScriptScoreQuery extends Query {
             return new FilterLeafCollector(collector) {
                 @Override
                 public void setScorer(Scorable scorer) throws IOException {
-                    in.setScorer(new ScriptScorable(scoreScript, scorer, subQueryScoreMode, null));
+                    in.setScorer(new ScriptScorable(scoreScript, scorer, subQueryScoreMode, boost, null));
                 }
             };
         }