浏览代码

Fix LTR query feature with phrases (and two-phase) queries (#125103)

Query features should verify that docs match the two-phase iterator.
Jim Ferenczi 6 月之前
父节点
当前提交
7c5be1257f

+ 5 - 0
docs/changelog/125103.yaml

@@ -0,0 +1,5 @@
+pr: 125103
+summary: Fix LTR query feature with phrases (and two-phase) queries
+area: Ranking
+type: bug
+issues: []

+ 28 - 21
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java

@@ -15,7 +15,6 @@ import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.Weight;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -25,11 +24,11 @@ import java.util.Map;
  * respective feature name.
  */
 public class QueryFeatureExtractor implements FeatureExtractor {
-
     private final List<String> featureNames;
     private final List<Weight> weights;
-    private final List<Scorer> scorers;
-    private DisjunctionDISIApproximation rankerIterator;
+
+    private final DisiPriorityQueue subScorers;
+    private DisjunctionDISIApproximation approximation;
 
     public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
         if (featureNames.size() != weights.size()) {
@@ -37,40 +36,40 @@ public class QueryFeatureExtractor implements FeatureExtractor {
         }
         this.featureNames = featureNames;
         this.weights = weights;
-        this.scorers = new ArrayList<>(weights.size());
+        this.subScorers = new DisiPriorityQueue(weights.size());
     }
 
     @Override
     public void setNextReader(LeafReaderContext segmentContext) throws IOException {
-        DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
-        scorers.clear();
-        for (Weight weight : weights) {
+        subScorers.clear();
+        for (int i = 0; i < weights.size(); i++) {
+            var weight = weights.get(i);
             if (weight == null) {
-                scorers.add(null);
                 continue;
             }
             Scorer scorer = weight.scorer(segmentContext);
             if (scorer != null) {
-                disiPriorityQueue.add(new DisiWrapper(scorer, false));
+                subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
             }
-            scorers.add(scorer);
         }
-
-        rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null;
+        approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
     }
 
     @Override
     public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
-        if (rankerIterator == null) {
+        if (approximation == null || approximation.docID() > docId) {
             return;
         }
-
-        rankerIterator.advance(docId);
-        for (int i = 0; i < featureNames.size(); i++) {
-            Scorer scorer = scorers.get(i);
-            // Do we have a scorer, and does it match the provided document?
-            if (scorer != null && scorer.docID() == docId) {
-                featureMap.put(featureNames.get(i), scorer.score());
+        if (approximation.docID() < docId) {
+            approximation.advance(docId);
+        }
+        if (approximation.docID() != docId) {
+            return;
+        }
+        var w = (FeatureDisiWrapper) subScorers.topList();
+        for (; w != null; w = (FeatureDisiWrapper) w.next) {
+            if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
+                featureMap.put(w.featureName, w.scorable.score());
             }
         }
     }
@@ -80,4 +79,12 @@ public class QueryFeatureExtractor implements FeatureExtractor {
         return featureNames;
     }
 
+    private static class FeatureDisiWrapper extends DisiWrapper {
+        final String featureName;
+
+        FeatureDisiWrapper(Scorer scorer, String featureName) {
+            super(scorer, false);
+            this.featureName = featureName;
+        }
+    }
 }

+ 106 - 87
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java

@@ -12,7 +12,7 @@ import org.apache.lucene.document.Field;
 import org.apache.lucene.document.IntField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
@@ -43,13 +43,11 @@ import static org.mockito.Mockito.mock;
 
 public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
 
-    private Directory dir;
-    private IndexReader reader;
-    private IndexSearcher searcher;
-
-    private void addDocs(String[] textValues, int[] numberValues) throws IOException {
-        dir = newDirectory();
-        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
+    private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException {
+        var config = newIndexWriterConfig();
+        // override the merge policy to ensure that docs remain in the same ingestion order
+        config.setMergePolicy(newLogMergePolicy(random()));
+        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) {
             for (int i = 0; i < textValues.length; i++) {
                 Document doc = new Document();
                 doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
@@ -59,98 +57,119 @@ public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
                     indexWriter.flush();
                 }
             }
-            reader = indexWriter.getReader();
+            return indexWriter.getReader();
         }
-        searcher = newSearcher(reader);
-        searcher.setSimilarity(new ClassicSimilarity());
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127")
     public void testQueryExtractor() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
-        QueryRewriteContext ctx = createQueryRewriteContext();
-        List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
-            new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
-                .rewrite(ctx),
-            new QueryExtractorBuilder(
-                "number_score",
-                QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_none",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_missing_field",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
-            ).rewrite(ctx)
-        );
-        SearchExecutionContext dummySEC = createSearchExecutionContext();
-        List<Weight> weights = new ArrayList<>();
-        List<String> featureNames = new ArrayList<>();
-        for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
-            Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
-            Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
-            weights.add(weight);
-            featureNames.add(qeb.featureName());
-        }
-        QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
-        List<Map<String, Object>> extractedFeatures = new ArrayList<>();
-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            queryFeatureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                queryFeatureExtractor.addFeatures(featureMap, i);
-                extractedFeatures.add(featureMap);
+        try (var dir = newDirectory()) {
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());
+                QueryRewriteContext ctx = createQueryRewriteContext();
+                List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
+                    new QueryExtractorBuilder(
+                        "text_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "number_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_none",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_missing_field",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "phrase_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
+                    ).rewrite(ctx)
+                );
+                SearchExecutionContext dummySEC = createSearchExecutionContext();
+                List<Weight> weights = new ArrayList<>();
+                List<String> featureNames = new ArrayList<>();
+                for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
+                    Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
+                    Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
+                    weights.add(weight);
+                    featureNames.add(qeb.featureName());
+                }
+                QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                List<Map<String, Object>> extractedFeatures = new ArrayList<>();
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    queryFeatureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        queryFeatureExtractor.addFeatures(featureMap, i);
+                        extractedFeatures.add(featureMap);
+                    }
+                }
+                assertThat(extractedFeatures, hasSize(4));
+                // Should never add features for queries that don't match a document or on documents where the field is missing
+                for (Map<String, Object> features : extractedFeatures) {
+                    assertThat(features, not(hasKey("matching_none")));
+                    assertThat(features, not(hasKey("matching_missing_field")));
+                }
+                // First two only match the text field
+                assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
+                assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
+                assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
+
+                // Only matches the range query
+                assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
+                assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
+                assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
+
+                // No query matches
+                assertThat(extractedFeatures.get(3), anEmptyMap());
             }
         }
-        assertThat(extractedFeatures, hasSize(4));
-        // Should never add features for queries that don't match a document or on documents where the field is missing
-        for (Map<String, Object> features : extractedFeatures) {
-            assertThat(features, not(hasKey("matching_none")));
-            assertThat(features, not(hasKey("matching_missing_field")));
-        }
-        // First two only match the text field
-        assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
-        assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
-        assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
-        assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
-        // Only matches the range query
-        assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
-        assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
-        // No query matches
-        assertThat(extractedFeatures.get(3), anEmptyMap());
-        reader.close();
-        dir.close();
     }
 
     public void testEmptyDisiPriorityQueue() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
+        try (var dir = newDirectory()) {
+            var config = newIndexWriterConfig();
+            config.setMergePolicy(NoMergePolicy.INSTANCE);
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {
 
-        // Scorers returned by weights are null
-        List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
-        List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());
 
-        QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                // Scorers returned by weights are null
+                List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
+                List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
 
-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            featureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                featureExtractor.addFeatures(featureMap, i);
-                assertThat(featureMap, anEmptyMap());
+                QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    featureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        featureExtractor.addFeatures(featureMap, i);
+                        assertThat(featureMap, anEmptyMap());
+                    }
+                }
             }
         }
-
-        reader.close();
-        dir.close();
     }
 }