Browse Source

Drop boost from runtime distance feature query (#63949)

This drops the `boost` parameter of the `distance_feature` query builder
internally, relying on our query building infrastructure to wrap the
query in a `boosting` query.

Relates to #63767
Nik Everett 5 years ago
parent
commit
f2bcc77586

+ 3 - 2
server/src/main/java/org/elasticsearch/index/mapper/DateFieldMapper.java

@@ -432,10 +432,11 @@ public final class DateFieldMapper extends ParametrizedFieldMapper {
         }
 
         @Override
-        public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
+        public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
             long originLong = parseToLong(origin, true, null, null, context::nowInMillis);
             TimeValue pivotTime = TimeValue.parseTimeValue(pivot, "distance_feature.pivot");
-            return resolution.distanceFeatureQuery(name(), boost, originLong, pivotTime);
+            // As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
+            return resolution.distanceFeatureQuery(name(), 1.0f, originLong, pivotTime);
         }
 
         @Override

+ 3 - 2
server/src/main/java/org/elasticsearch/index/mapper/GeoPointFieldMapper.java

@@ -194,7 +194,7 @@ public class GeoPointFieldMapper extends AbstractPointGeometryFieldMapper<List<P
         }
 
         @Override
-        public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
+        public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
             GeoPoint originGeoPoint;
             if (origin instanceof GeoPoint) {
                 originGeoPoint = (GeoPoint) origin;
@@ -205,7 +205,8 @@ public class GeoPointFieldMapper extends AbstractPointGeometryFieldMapper<List<P
                     "Must be of type [geo_point] or [string] for geo_point fields!");
             }
             double pivotDouble = DistanceUnit.DEFAULT.parse(pivot, DistanceUnit.DEFAULT);
-            return LatLonPoint.newDistanceFeatureQuery(name(), boost, originGeoPoint.lat(), originGeoPoint.lon(), pivotDouble);
+            // As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
+            return LatLonPoint.newDistanceFeatureQuery(name(), 1.0f, originGeoPoint.lat(), originGeoPoint.lon(), pivotDouble);
         }
     }
 

+ 1 - 1
server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java

@@ -294,7 +294,7 @@ public abstract class MappedFieldType {
             + "] which is of type [" + typeName() + "]");
     }
 
-    public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
+    public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
         throw new IllegalArgumentException("Illegal data type of [" + typeName() + "]!"+
             "[" + DistanceFeatureQueryBuilder.NAME + "] query can only be run on a date, date_nanos or geo_point field type!");
     }

+ 1 - 2
server/src/main/java/org/elasticsearch/index/query/DistanceFeatureQueryBuilder.java

@@ -112,8 +112,7 @@ public class DistanceFeatureQueryBuilder extends AbstractQueryBuilder<DistanceFe
         if (fieldType == null) {
             return Queries.newMatchNoDocsQuery("Can't run [" + NAME + "] query on unmapped fields!");
         }
-        // As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
-        return fieldType.distanceFeatureQuery(origin.origin(), pivot, 1.0f, context);
+        return fieldType.distanceFeatureQuery(origin.origin(), pivot, context);
     }
 
     String fieldName() {

+ 2 - 3
x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/mapper/DateScriptFieldType.java

@@ -81,7 +81,7 @@ public class DateScriptFieldType extends AbstractScriptFieldType<DateFieldScript
     }
 
     @Override
-    public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
+    public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
         checkAllowExpensiveQueries(context);
         return DateFieldType.handleNow(context, now -> {
             long originLong = DateFieldType.parseToLong(
@@ -98,8 +98,7 @@ public class DateScriptFieldType extends AbstractScriptFieldType<DateFieldScript
                 leafFactory(context)::newInstance,
                 name(),
                 originLong,
-                pivotTime.getMillis(),
-                boost
+                pivotTime.getMillis()
             );
         });
     }

+ 7 - 16
x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQuery.java

@@ -27,20 +27,17 @@ import java.util.function.Function;
 public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFieldQuery<AbstractLongFieldScript> {
     private final long origin;
     private final long pivot;
-    private final float boost;
 
     public LongScriptFieldDistanceFeatureQuery(
         Script script,
         Function<LeafReaderContext, AbstractLongFieldScript> leafFactory,
         String fieldName,
         long origin,
-        long pivot,
-        float boost
+        long pivot
     ) {
         super(script, fieldName, leafFactory);
         this.origin = origin;
         this.pivot = pivot;
-        this.boost = boost;
     }
 
     @Override
@@ -70,12 +67,11 @@ public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFie
                 AbstractLongFieldScript script = scriptContextFunction().apply(context);
                 script.runForDoc(doc);
                 long value = valueWithMinAbsoluteDistance(script);
-                float weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost;
-                float score = score(weight, distanceFor(value));
+                float score = score(boost, distanceFor(value));
                 return Explanation.match(
                     score,
                     "Distance score, computed as weight * pivot / (pivot + abs(value - origin)) from:",
-                    Explanation.match(weight, "weight"),
+                    Explanation.match(boost, "weight"),
                     Explanation.match(pivot, "pivot"),
                     Explanation.match(origin, "origin"),
                     Explanation.match(value, "current value")
@@ -105,7 +101,7 @@ public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFie
                 }
             };
             disi = TwoPhaseIterator.asDocIdSetIterator(twoPhase);
-            this.weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost;
+            this.weight = boost;
         }
 
         @Override
@@ -179,15 +175,14 @@ public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFie
         }
         b.append(getClass().getSimpleName());
         b.append("(origin=").append(origin);
-        b.append(",pivot=").append(pivot);
-        b.append(",boost=").append(boost).append(")");
+        b.append(",pivot=").append(pivot).append(")");
         return b.toString();
 
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), origin, pivot, boost);
+        return Objects.hash(super.hashCode(), origin, pivot);
     }
 
     @Override
@@ -196,7 +191,7 @@ public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFie
             return false;
         }
         LongScriptFieldDistanceFeatureQuery other = (LongScriptFieldDistanceFeatureQuery) obj;
-        return origin == other.origin && pivot == other.pivot && boost == other.boost;
+        return origin == other.origin && pivot == other.pivot;
     }
 
     @Override
@@ -214,8 +209,4 @@ public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFie
     long pivot() {
         return pivot;
     }
-
-    float boost() {
-        return boost;
-    }
 }

+ 2 - 2
x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/DateScriptFieldTypeTests.java

@@ -199,7 +199,7 @@ public class DateScriptFieldTypeTests extends AbstractNonTextScriptFieldTypeTest
             );
             try (DirectoryReader reader = iw.getReader()) {
                 IndexSearcher searcher = newSearcher(reader);
-                Query query = simpleMappedFieldType().distanceFeatureQuery(1595432181354L, "1ms", 1, mockContext());
+                Query query = simpleMappedFieldType().distanceFeatureQuery(1595432181354L, "1ms", mockContext());
                 TopDocs docs = searcher.search(query, 4);
                 assertThat(docs.scoreDocs, arrayWithSize(3));
                 assertThat(readSource(reader, docs.scoreDocs[0].doc), equalTo("{\"timestamp\": [1595432181354]}"));
@@ -228,7 +228,7 @@ public class DateScriptFieldTypeTests extends AbstractNonTextScriptFieldTypeTest
     }
 
     private Query randomDistanceFeatureQuery(MappedFieldType ft, QueryShardContext ctx) {
-        return ft.distanceFeatureQuery(randomDate(), randomTimeValue(), randomFloat(), ctx);
+        return ft.distanceFeatureQuery(randomDate(), randomTimeValue(), ctx);
     }
 
     @Override

+ 11 - 23
x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQueryTests.java

@@ -35,19 +35,12 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
     protected LongScriptFieldDistanceFeatureQuery createTestInstance() {
         long origin = randomLong();
         long pivot = randomValueOtherThan(origin, ESTestCase::randomLong);
-        return new LongScriptFieldDistanceFeatureQuery(randomScript(), leafFactory, randomAlphaOfLength(5), origin, pivot, randomFloat());
+        return new LongScriptFieldDistanceFeatureQuery(randomScript(), leafFactory, randomAlphaOfLength(5), origin, pivot);
     }
 
     @Override
     protected LongScriptFieldDistanceFeatureQuery copy(LongScriptFieldDistanceFeatureQuery orig) {
-        return new LongScriptFieldDistanceFeatureQuery(
-            orig.script(),
-            leafFactory,
-            orig.fieldName(),
-            orig.origin(),
-            orig.pivot(),
-            orig.boost()
-        );
+        return new LongScriptFieldDistanceFeatureQuery(orig.script(), leafFactory, orig.fieldName(), orig.origin(), orig.pivot());
     }
 
     @Override
@@ -56,8 +49,7 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
         String fieldName = orig.fieldName();
         long origin = orig.origin();
         long pivot = orig.pivot();
-        float boost = orig.boost();
-        switch (randomInt(4)) {
+        switch (randomInt(3)) {
             case 0:
                 script = randomValueOtherThan(script, this::randomScript);
                 break;
@@ -70,13 +62,10 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
             case 3:
                 pivot = randomValueOtherThan(origin, () -> randomValueOtherThan(orig.pivot(), ESTestCase::randomLong));
                 break;
-            case 4:
-                boost = randomValueOtherThan(boost, ESTestCase::randomFloat);
-                break;
             default:
                 fail();
         }
-        return new LongScriptFieldDistanceFeatureQuery(script, leafFactory, fieldName, origin, pivot, boost);
+        return new LongScriptFieldDistanceFeatureQuery(script, leafFactory, fieldName, origin, pivot);
     }
 
     @Override
@@ -105,12 +94,13 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
                     leafFactory,
                     "test",
                     1595432181351L,
-                    6L,
-                    between(1, 100)
+                    3L
                 );
-                TopDocs td = searcher.search(query, 1);
-                assertThat(td.scoreDocs[0].score, equalTo(query.boost()));
+                TopDocs td = searcher.search(query, 2);
+                assertThat(td.scoreDocs[0].score, equalTo(1.0f));
                 assertThat(td.scoreDocs[0].doc, equalTo(1));
+                assertThat(td.scoreDocs[1].score, equalTo(.5f));
+                assertThat(td.scoreDocs[1].doc, equalTo(0));
             }
         }
     }
@@ -124,7 +114,7 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
                 float boost = randomFloat();
                 assertThat(
                     query.createWeight(searcher, ScoreMode.COMPLETE, boost).scorer(reader.leaves().get(0)).getMaxScore(randomInt()),
-                    equalTo(query.boost() * boost)
+                    equalTo(boost)
                 );
             }
         }
@@ -134,9 +124,7 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
     protected void assertToString(LongScriptFieldDistanceFeatureQuery query) {
         assertThat(
             query.toString(query.fieldName()),
-            equalTo(
-                "LongScriptFieldDistanceFeatureQuery(origin=" + query.origin() + ",pivot=" + query.pivot() + ",boost=" + query.boost() + ")"
-            )
+            equalTo("LongScriptFieldDistanceFeatureQuery(origin=" + query.origin() + ",pivot=" + query.pivot() + ")")
         );
     }