Ver código fonte

Add tests for equals and hashCode and fix FiltersFunctionScoreQuery equals and hashCode impls

Relates to #15676
Simon Willnauer 9 anos atrás
pai
commit
0a816cd343

+ 4 - 3
core/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java

@@ -369,12 +369,13 @@ public class FiltersFunctionScoreQuery extends Query {
         }
         FiltersFunctionScoreQuery other = (FiltersFunctionScoreQuery) o;
         return Objects.equals(this.subQuery, other.subQuery) && this.maxBoost == other.maxBoost &&
-                Objects.equals(this.combineFunction, other.combineFunction) && Objects.equals(this.minScore, other.minScore) &&
-                Arrays.equals(this.filterFunctions, other.filterFunctions);
+            Objects.equals(this.combineFunction, other.combineFunction) && Objects.equals(this.minScore, other.minScore) &&
+            Objects.equals(this.scoreMode, other.scoreMode) &&
+            Arrays.equals(this.filterFunctions, other.filterFunctions);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), subQuery, maxBoost, combineFunction, minScore, filterFunctions);
+        return Objects.hash(super.hashCode(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(filterFunctions));
     }
 }

+ 138 - 6
core/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreTests.java

@@ -358,11 +358,11 @@ public class FunctionScoreTests extends ESTestCase {
 
         // now test all together
         functionExplanation = getFiltersFunctionScoreExplanation(searcher
-                , RANDOM_SCORE_FUNCTION
-                , FIELD_VALUE_FACTOR_FUNCTION
-                , GAUSS_DECAY_FUNCTION
-                , EXP_DECAY_FUNCTION
-                , LIN_DECAY_FUNCTION
+            , RANDOM_SCORE_FUNCTION
+            , FIELD_VALUE_FACTOR_FUNCTION
+            , GAUSS_DECAY_FUNCTION
+            , EXP_DECAY_FUNCTION
+            , LIN_DECAY_FUNCTION
         );
 
         checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0);
@@ -398,7 +398,7 @@ public class FunctionScoreTests extends ESTestCase {
         FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[scoreFunctions.length];
         for (int i = 0; i < scoreFunctions.length; i++) {
             filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(
-                    new TermQuery(TERM), scoreFunctions[i]);
+                new TermQuery(TERM), scoreFunctions[i]);
         }
         return new FiltersFunctionScoreQuery(new TermQuery(TERM), scoreMode, filterFunctions, Float.MAX_VALUE, Float.MAX_VALUE * -1, combineFunction);
     }
@@ -610,4 +610,136 @@ public class FunctionScoreTests extends ESTestCase {
             assertNotNull(scorer.twoPhaseIterator());
         }
     }
+
+    public void testFunctionScoreHashCodeAndEquals() {
+        Float minScore = randomBoolean() ? null : 1.0f;
+        CombineFunction combineFunction = randomFrom(CombineFunction.values());
+        float maxBoost = randomBoolean() ? Float.POSITIVE_INFINITY : randomFloat();
+        ScoreFunction function = randomBoolean() ? null : new ScoreFunction(combineFunction) {
+            @Override
+            public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
+                return null;
+            }
+
+            @Override
+            public boolean needsScores() {
+                return false;
+            }
+            @Override
+            protected boolean doEquals(ScoreFunction other) {
+                return other == this;
+            }
+        };
+
+        FunctionScoreQuery q = new FunctionScoreQuery(new TermQuery(new Term("foo", "bar")), function, minScore, combineFunction, maxBoost);
+        FunctionScoreQuery q1 = new FunctionScoreQuery(new TermQuery(new Term("foo", "bar")), function, minScore, combineFunction, maxBoost);
+        assertEquals(q, q);
+        assertEquals(q.hashCode(), q.hashCode());
+        assertEquals(q, q1);
+        assertEquals(q.hashCode(), q1.hashCode());
+
+        FunctionScoreQuery diffQuery = new FunctionScoreQuery(new TermQuery(new Term("foo", "baz")), function, minScore, combineFunction, maxBoost);
+        FunctionScoreQuery diffMinScore = new FunctionScoreQuery(q.getSubQuery(), function, minScore == null ? 1.0f : null, combineFunction, maxBoost);
+        ScoreFunction otherFunciton = function == null ? new ScoreFunction(combineFunction) {
+            @Override
+            public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
+                return null;
+            }
+
+            @Override
+            public boolean needsScores() {
+                return false;
+            }
+
+            @Override
+            protected boolean doEquals(ScoreFunction other) {
+                return other == this;
+            }
+
+        } : null;
+        FunctionScoreQuery diffFunction = new FunctionScoreQuery(q.getSubQuery(), otherFunciton, minScore, combineFunction, maxBoost);
+        FunctionScoreQuery diffMaxBoost = new FunctionScoreQuery(new TermQuery(new Term("foo", "bar")), function, minScore, combineFunction, maxBoost == 1.0f ? 0.9f : 1.0f);
+        q1.setBoost(3.0f);
+        FunctionScoreQuery[] queries = new FunctionScoreQuery[] {
+            diffFunction,
+            diffMinScore,
+            diffQuery,
+            q,
+            q1,
+            diffMaxBoost
+        };
+        final int numIters = randomIntBetween(20, 100);
+        for (int i = 0; i < numIters; i++) {
+            FunctionScoreQuery left = randomFrom(queries);
+            FunctionScoreQuery right = randomFrom(queries);
+            if (left == right) {
+                assertEquals(left, right);
+                assertEquals(left.hashCode(), right.hashCode());
+            } else {
+                assertNotEquals(left + " == " + right, left, right);
+            }
+        }
+
+    }
+
+    public void testFilterFunctionScoreHashCodeAndEquals() {
+        ScoreMode mode = randomFrom(ScoreMode.values());
+        CombineFunction combineFunction = randomFrom(CombineFunction.values());
+        ScoreFunction scoreFunction = new ScoreFunction(combineFunction) {
+            @Override
+            public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
+                return null;
+            }
+
+            @Override
+            public boolean needsScores() {
+                return false;
+            }
+
+            @Override
+            protected boolean doEquals(ScoreFunction other) {
+                return other == this;
+            }
+        };
+        Float minScore = randomBoolean() ? null : 1.0f;
+        Float maxBoost = randomBoolean() ? Float.POSITIVE_INFINITY : randomFloat();
+
+        FilterFunction function = new FilterFunction(new TermQuery(new Term("filter", "query")), scoreFunction);
+        FiltersFunctionScoreQuery q = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, new FilterFunction[] {function}, maxBoost, minScore, combineFunction);
+        FiltersFunctionScoreQuery q1 = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, new FilterFunction[] {function}, maxBoost, minScore, combineFunction);
+        assertEquals(q, q);
+        assertEquals(q.hashCode(), q.hashCode());
+        assertEquals(q, q1);
+        assertEquals(q.hashCode(), q1.hashCode());
+        FiltersFunctionScoreQuery diffCombineFunc = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, new FilterFunction[] {function}, maxBoost, minScore, combineFunction == CombineFunction.AVG ? CombineFunction.MAX : CombineFunction.AVG);
+        FiltersFunctionScoreQuery diffQuery = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "baz")), mode, new FilterFunction[] {function}, maxBoost, minScore, combineFunction);
+        FiltersFunctionScoreQuery diffMode = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode == ScoreMode.AVG ? ScoreMode.FIRST : ScoreMode.AVG, new FilterFunction[] {function}, maxBoost, minScore, combineFunction);
+        FiltersFunctionScoreQuery diffMaxBoost = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, new FilterFunction[] {function}, maxBoost == 1.0f ? 0.9f : 1.0f, minScore, combineFunction);
+        FiltersFunctionScoreQuery diffMinScore = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, new FilterFunction[] {function}, maxBoost, minScore == null ? 0.9f : null, combineFunction);
+        FilterFunction otherFunc = new FilterFunction(new TermQuery(new Term("filter", "other_query")), scoreFunction);
+        FiltersFunctionScoreQuery diffFunc = new FiltersFunctionScoreQuery(new TermQuery(new Term("foo", "bar")), mode, randomBoolean() ? new FilterFunction[] {function, otherFunc} : new FilterFunction[] {otherFunc}, maxBoost, minScore, combineFunction);
+        q1.setBoost(3.0f);
+
+        FiltersFunctionScoreQuery[] queries = new FiltersFunctionScoreQuery[] {
+            diffQuery,
+            diffMaxBoost,
+            diffMinScore,
+            diffMode,
+            diffFunc,
+            q,
+            q1,
+            diffCombineFunc
+        };
+        final int numIters = randomIntBetween(20, 100);
+        for (int i = 0; i < numIters; i++) {
+            FiltersFunctionScoreQuery left = randomFrom(queries);
+            FiltersFunctionScoreQuery right = randomFrom(queries);
+            if (left == right) {
+                assertEquals(left, right);
+                assertEquals(left.hashCode(), right.hashCode());
+            } else {
+                assertNotEquals(left + " == " + right, left, right);
+            }
+        }
+    }
 }