Browse Source

Add linear function to rank_feature query (#67438)

This adds a linear function to the set of functions available
for rank_feature query

Closes #49859
Mayya Sharipova 4 years ago
parent
commit
76482210b8

+ 40 - 4
docs/reference/query-dsl/rank-feature-query.asciidoc

@@ -27,6 +27,7 @@ query supports the following mathematical functions:
 * <<rank-feature-query-saturation,Saturation>>
 * <<rank-feature-query-logarithm,Logarithm>>
 * <<rank-feature-query-sigmoid,Sigmoid>>
+* <<rank-feature-query-linear,Linear>>
 
 If you don't know where to start, we recommend using the `saturation` function.
 If no function is provided, the `rank_feature` query uses the `saturation`
@@ -126,7 +127,7 @@ The following query searches for `2016` and boosts relevance scores based on
 
 [source,console]
 ----
-GET /test/_search 
+GET /test/_search
 {
   "query": {
     "bool": {
@@ -190,7 +191,7 @@ value of the rank feature `field`. If no function is provided, the `rank_feature
 query defaults to the `saturation` function. See
 <<rank-feature-query-saturation,Saturation>> for more information.
 
-Only one function `saturation`, `log`, or `sigmoid` can be provided.
+Only one function `saturation`, `log`, `sigmoid`, or `linear` can be provided.
 --
 
 `log`::
@@ -201,7 +202,7 @@ function used to boost <<relevance-scores,relevance scores>> based on the
 value of the rank feature `field`. See
 <<rank-feature-query-logarithm,Logarithm>> for more information.
 
-Only one function `saturation`, `log`, or `sigmoid` can be provided.
+Only one function `saturation`, `log`, `sigmoid`, or `linear` can be provided.
 --
 
 `sigmoid`::
@@ -212,7 +213,18 @@ to boost <<relevance-scores,relevance scores>> based on the value of the
 rank feature `field`. See <<rank-feature-query-sigmoid,Sigmoid>> for more
 information.
 
-Only one function `saturation`, `log`, or `sigmoid` can be provided.
+Only one function `saturation`, `log`, `sigmoid`, or `linear` can be provided.
+--
+
+`linear`::
++
+--
+(Optional, <<rank-feature-query-linear,function object>>) Linear function used
+to boost <<relevance-scores,relevance scores>> based on the value of the
+rank feature `field`. See <<rank-feature-query-linear,Linear>> for more
+information.
+
+Only one function `saturation`, `log`, `sigmoid`, or `linear` can be provided.
 --
 
 
@@ -311,3 +323,27 @@ GET /test/_search
   }
 }
 --------------------------------------------------
+[[rank-feature-query-linear]]
+===== Linear
+The `linear` function is the simplest function, and gives a score equal
+to the indexed value of `S`, where `S` is the value of the rank feature
+field.
+If a rank feature field is indexed with `"positive_score_impact": true`,
+its indexed value is equal to `S`, rounded to preserve only
+9 significant bits of precision.
+If a rank feature field is indexed with `"positive_score_impact": false`,
+its indexed value is equal to `1/S`, rounded to preserve only 9 significant
+bits of precision.
+
+[source,console]
+--------------------------------------------------
+GET /test/_search
+{
+  "query": {
+    "rank_feature": {
+      "field": "pagerank",
+      "linear": {}
+    }
+  }
+}
+--------------------------------------------------

+ 54 - 4
modules/mapper-extras/src/main/java/org/elasticsearch/index/query/RankFeatureQueryBuilder.java

@@ -26,6 +26,7 @@ import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.RankFeatureFieldMapper.RankFeatureFieldType;
@@ -104,7 +105,7 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
             }
 
             @Override
-            Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
+            Query toQuery(String field, String feature, boolean positiveScoreImpact) {
                 if (positiveScoreImpact == false) {
                     throw new IllegalArgumentException("Cannot use the [log] function with a field that has a negative score impact as " +
                             "it would trigger negative scores");
@@ -175,7 +176,7 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
             }
 
             @Override
-            Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
+            Query toQuery(String field, String feature, boolean positiveScoreImpact) {
                 if (pivot == null) {
                     return FeatureField.newSaturationQuery(field, feature);
                 } else {
@@ -240,10 +241,55 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
             }
 
             @Override
-            Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
+            Query toQuery(String field, String feature, boolean positiveScoreImpact) {
                 return FeatureField.newSigmoidQuery(field, feature, DEFAULT_BOOST, pivot, exp);
             }
         }
+
+        /**
+         * A scoring function that scores documents as simply {@code S}
+         * where S is the indexed value of the static feature.
+         */
+        public static class Linear extends ScoreFunction {
+
+            private static final ObjectParser<Linear, Void> PARSER = new ObjectParser<>("linear", Linear::new);
+
+            public Linear() {
+            }
+
+            private Linear(StreamInput in) {
+                this();
+            }
+
+            @Override
+            public boolean equals(Object obj) {
+                if (obj == null || getClass() != obj.getClass()) {
+                    return false;
+                }
+                return true;
+            }
+
+            @Override
+            public int hashCode() {
+                return getClass().hashCode();
+            }
+
+            @Override
+            void writeTo(StreamOutput out) throws IOException {
+                out.writeByte((byte) 3);
+            }
+
+            @Override
+            void doXContent(XContentBuilder builder) throws IOException {
+                builder.startObject("linear");
+                builder.endObject();
+            }
+
+            @Override
+            Query toQuery(String field, String feature, boolean positiveScoreImpact) {
+                return FeatureField.newLinearQuery(field, feature, DEFAULT_BOOST);
+            }
+        }
     }
 
     private static ScoreFunction readScoreFunction(StreamInput in) throws IOException {
@@ -255,6 +301,8 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
             return new ScoreFunction.Saturation(in);
         case 2:
             return new ScoreFunction.Sigmoid(in);
+        case 3:
+            return new ScoreFunction.Linear(in);
         default:
             throw new IOException("Illegal score function id: " + b);
         }
@@ -268,7 +316,7 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
                 long numNonNulls = Arrays.stream(args, 3, args.length).filter(Objects::nonNull).count();
                 final RankFeatureQueryBuilder query;
                 if (numNonNulls > 1) {
-                    throw new IllegalArgumentException("Can only specify one of [log], [saturation] and [sigmoid]");
+                    throw new IllegalArgumentException("Can only specify one of [log], [saturation], [sigmoid] and [linear]");
                 } else if (numNonNulls == 0) {
                     query = new RankFeatureQueryBuilder(field, new ScoreFunction.Saturation());
                 } else {
@@ -292,6 +340,8 @@ public final class RankFeatureQueryBuilder extends AbstractQueryBuilder<RankFeat
                 ScoreFunction.Saturation.PARSER, new ParseField("saturation"));
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
                 ScoreFunction.Sigmoid.PARSER, new ParseField("sigmoid"));
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
+                ScoreFunction.Linear.PARSER, new ParseField("linear"));
     }
 
     public static final String NAME = "rank_feature";

+ 9 - 0
modules/mapper-extras/src/main/java/org/elasticsearch/index/query/RankFeatureQueryBuilders.java

@@ -64,4 +64,13 @@ public final class RankFeatureQueryBuilders {
         return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Sigmoid(pivot, exp));
     }
 
+    /**
+     * Return a new {@link RankFeatureQueryBuilder} that will score documents as
+     * {@code S} where S is the indexed value of the static feature.
+     * @param fieldName     field that stores features
+     */
+    public static RankFeatureQueryBuilder linear(String fieldName) {
+        return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Linear());
+    }
+
 }

+ 6 - 3
modules/mapper-extras/src/test/java/org/elasticsearch/index/query/RankFeatureQueryBuilderTests.java

@@ -60,7 +60,7 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
     protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
         ScoreFunction function;
         boolean mayUseNegativeField = true;
-        switch (random().nextInt(3)) {
+        switch (random().nextInt(4)) {
         case 0:
             mayUseNegativeField = false;
             function = new ScoreFunction.Log(1 + randomFloat());
@@ -75,6 +75,9 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
         case 2:
             function = new ScoreFunction.Sigmoid(randomFloat(), randomFloat());
             break;
+        case 3:
+            function = new ScoreFunction.Linear();
+            break;
         default:
             throw new AssertionError();
         }
@@ -106,7 +109,7 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
         assertEquals(FeatureField.newSaturationQuery("_feature", "my_feature_field"), parsedQuery);
     }
 
-    public void testIllegalField() throws IOException {
+    public void testIllegalField() {
         String query = "{\n" +
                 "    \"rank_feature\" : {\n" +
                 "        \"field\": \"" + TEXT_FIELD_NAME + "\"\n" +
@@ -118,7 +121,7 @@ public class RankFeatureQueryBuilderTests extends AbstractQueryTestCase<RankFeat
             e.getMessage());
     }
 
-    public void testIllegalCombination() throws IOException {
+    public void testIllegalCombination() {
         String query = "{\n" +
                 "    \"rank_feature\" : {\n" +
                 "        \"field\": \"my_negative_feature_field\",\n" +

+ 52 - 11
modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/rank_feature/10_basic.yml

@@ -37,7 +37,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -46,7 +46,7 @@ setup:
                 scaling_factor: 3
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -59,7 +59,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -68,7 +68,7 @@ setup:
                 pivot: 20
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -81,7 +81,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -91,7 +91,27 @@ setup:
                 exponent: 0.6
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
+
+  - match:
+      hits.hits.0._id: "2"
+
+  - match:
+      hits.hits.1._id: "1"
+
+---
+"Positive linear":
+  - do:
+      search:
+        index: test
+        body:
+          query:
+            rank_feature:
+              field: pagerank
+              linear: {}
+
+  - match:
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -105,7 +125,7 @@ setup:
   - do:
       catch: bad_request
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -118,7 +138,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -127,7 +147,7 @@ setup:
                 pivot: 20
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -140,7 +160,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -150,7 +170,28 @@ setup:
                 exponent: 0.6
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
+
+  - match:
+      hits.hits.0._id: "2"
+
+  - match:
+      hits.hits.1._id: "1"
+
+---
+"Negative linear":
+
+  - do:
+      search:
+        index: test
+        body:
+          query:
+            rank_feature:
+              field: url_length
+              linear: {}
+
+  - match:
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"

+ 31 - 6
modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/rank_features/10_basic.yml

@@ -36,7 +36,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -45,7 +45,7 @@ setup:
                 scaling_factor: 3
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -58,7 +58,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -67,7 +67,7 @@ setup:
                 pivot: 20
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
@@ -80,7 +80,7 @@ setup:
 
   - do:
       search:
-        rest_total_hits_as_int: true
+        index: test
         body:
           query:
             rank_feature:
@@ -90,10 +90,35 @@ setup:
                 exponent: 0.6
 
   - match:
-      hits.total: 2
+      hits.total.value: 2
 
   - match:
       hits.hits.0._id: "2"
 
   - match:
       hits.hits.1._id: "1"
+
+---
+"Linear":
+
+  - do:
+      search:
+        index: test
+        body:
+          query:
+            rank_feature:
+              field: tags.bar
+              linear: {}
+
+  - match:
+      hits.total.value: 2
+
+  - match:
+      hits.hits.0._id: "2"
+  - match:
+      hits.hits.0._score: 6.0
+
+  - match:
+      hits.hits.1._id: "1"
+  - match:
+      hits.hits.1._score: 5.0