Browse Source

Move SparseVectorQueryBuilder and TextExpansionQueryBuilder to x-pack core (#117857) (#117896)

This commit moves the SparseVectorQueryBuilder and TextExpansionQueryBuilder classes to the x-pack core module, enabling other modules to utilize these query builders.
Additionally, it introduces a SparseVectorQueryWrapper to extract sparse vector queries from standard Lucene queries.
This is needed for supporting semantic highlighting with sparse vector fields as follow up.
Jim Ferenczi 10 months ago
parent
commit
2cdc2891f4

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

@@ -71,6 +71,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+import org.elasticsearch.xpack.core.ml.search.TextExpansionQueryBuilder;
 import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
 import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
 import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;
@@ -398,6 +400,14 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, SearchPlu
     @Override
     public List<SearchPlugin.QuerySpec<?>> getQueries() {
         return List.of(
+            new QuerySpec<>(SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new, SparseVectorQueryBuilder::fromXContent),
+            new QuerySpec<QueryBuilder>(
+                TextExpansionQueryBuilder.NAME,
+                TextExpansionQueryBuilder::new,
+                TextExpansionQueryBuilder::fromXContent
+            ),
+            // TODO: The WeightedTokensBuilder is slated for removal after the SparseVectorQueryBuilder is available.
+            // The logic to create a Boolean query based on weighted tokens will remain and/or be moved to server.
             new SearchPlugin.QuerySpec<QueryBuilder>(
                 WeightedTokensQueryBuilder.NAME,
                 WeightedTokensQueryBuilder::new,

+ 2 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilder.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.queries;
+package org.elasticsearch.xpack.core.ml.search;
 
 import org.apache.lucene.search.MatchNoDocsQuery;
 import org.apache.lucene.search.Query;
@@ -33,9 +33,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
-import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
-import org.elasticsearch.xpack.core.ml.search.WeightedTokensUtils;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -210,7 +207,7 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
 
         return (shouldPruneTokens)
             ? WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, ft, context)
-            : WeightedTokensUtils.queryBuilderWithAllTokens(queryVectors, ft, context);
+            : WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, ft, context);
     }
 
     @Override

+ 77 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryWrapper.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.search;
+
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.elasticsearch.index.query.SearchExecutionContext;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A wrapper class for the Lucene query generated by {@link SparseVectorQueryBuilder#toQuery(SearchExecutionContext)}.
+ * This wrapper facilitates the extraction of the complete sparse vector query using a {@link QueryVisitor}.
+ */
+public class SparseVectorQueryWrapper extends Query {
+    private final String fieldName;
+    private final Query termsQuery;
+
+    public SparseVectorQueryWrapper(String fieldName, Query termsQuery) {
+        this.fieldName = fieldName;
+        this.termsQuery = termsQuery;
+    }
+
+    public Query getTermsQuery() {
+        return termsQuery;
+    }
+
+    @Override
+    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+        var rewrite = termsQuery.rewrite(indexSearcher);
+        if (rewrite != termsQuery) {
+            return new SparseVectorQueryWrapper(fieldName, rewrite);
+        }
+        return this;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+        return termsQuery.createWeight(searcher, scoreMode, boost);
+    }
+
+    @Override
+    public String toString(String field) {
+        return termsQuery.toString(field);
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+        if (visitor.acceptField(fieldName)) {
+            termsQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
+        }
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (sameClassAs(obj) == false) {
+            return false;
+        }
+        SparseVectorQueryWrapper that = (SparseVectorQueryWrapper) obj;
+        return fieldName.equals(that.fieldName) && termsQuery.equals(that.termsQuery);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classHash(), fieldName, termsQuery);
+    }
+}

+ 1 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.queries;
+package org.elasticsearch.xpack.core.ml.search;
 
 import org.apache.lucene.search.Query;
 import org.apache.lucene.util.SetOnce;
@@ -32,8 +32,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
-import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
-import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
 
 import java.io.IOException;
 import java.util.List;

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java

@@ -125,7 +125,7 @@ public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTok
         }
 
         return (this.tokenPruningConfig == null)
-            ? WeightedTokensUtils.queryBuilderWithAllTokens(tokens, ft, context)
+            ? WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, tokens, ft, context)
             : WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, tokens, ft, context);
     }
 

+ 8 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java

@@ -24,13 +24,18 @@ public final class WeightedTokensUtils {
 
     private WeightedTokensUtils() {}
 
-    public static Query queryBuilderWithAllTokens(List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context) {
+    public static Query queryBuilderWithAllTokens(
+        String fieldName,
+        List<WeightedToken> tokens,
+        MappedFieldType ft,
+        SearchExecutionContext context
+    ) {
         var qb = new BooleanQuery.Builder();
 
         for (var token : tokens) {
             qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
         }
-        return qb.setMinimumNumberShouldMatch(1).build();
+        return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build());
     }
 
     public static Query queryBuilderWithPrunedTokens(
@@ -64,7 +69,7 @@ public final class WeightedTokensUtils {
             }
         }
 
-        return qb.setMinimumNumberShouldMatch(1).build();
+        return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build());
     }
 
     /**

+ 11 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/SparseVectorQueryBuilderTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.queries;
+package org.elasticsearch.xpack.core.ml.search;
 
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.FeatureField;
@@ -40,9 +40,6 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
-import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
-import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.io.IOException;
 import java.lang.reflect.Method;
@@ -50,7 +47,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 
-import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD;
+import static org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.Matchers.either;
 import static org.hamcrest.Matchers.hasSize;
@@ -102,7 +99,7 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
 
     @Override
     protected Collection<Class<? extends Plugin>> getPlugins() {
-        return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class);
+        return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class);
     }
 
     @Override
@@ -156,8 +153,10 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
 
     @Override
     protected void doAssertLuceneQuery(SparseVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
-        assertThat(query, instanceOf(BooleanQuery.class));
-        BooleanQuery booleanQuery = (BooleanQuery) query;
+        assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
+        var sparseQuery = (SparseVectorQueryWrapper) query;
+        assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
+        BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
         assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
         assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
 
@@ -233,11 +232,13 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
 
     private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutionContext context) throws IOException {
         Query query = queryBuilder.doToQuery(context);
+        assertTrue(query instanceof SparseVectorQueryWrapper);
+        var sparseQuery = (SparseVectorQueryWrapper) query;
         if (queryBuilder.shouldPruneTokens()) {
             // It's possible that all documents were pruned for aggressive pruning configurations
-            assertTrue(query instanceof BooleanQuery || query instanceof MatchNoDocsQuery);
+            assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery || sparseQuery.getTermsQuery() instanceof MatchNoDocsQuery);
         } else {
-            assertTrue(query instanceof BooleanQuery);
+            assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery);
         }
     }
 

+ 6 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.queries;
+package org.elasticsearch.xpack.core.ml.search;
 
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.FeatureField;
@@ -35,10 +35,6 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
-import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
-import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
-import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.io.IOException;
 import java.lang.reflect.Method;
@@ -77,7 +73,7 @@ public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase<TextEx
 
     @Override
     protected Collection<Class<? extends Plugin>> getPlugins() {
-        return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class);
+        return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class);
     }
 
     @Override
@@ -129,8 +125,10 @@ public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase<TextEx
 
     @Override
     protected void doAssertLuceneQuery(TextExpansionQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
-        assertThat(query, instanceOf(BooleanQuery.class));
-        BooleanQuery booleanQuery = (BooleanQuery) query;
+        assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
+        var sparseQuery = (SparseVectorQueryWrapper) query;
+        assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
+        BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
         assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
         assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
 

+ 9 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java

@@ -271,8 +271,11 @@ public class WeightedTokensQueryBuilderTests extends AbstractQueryTestCase<Weigh
     }
 
     private void assertCorrectLuceneQuery(String name, Query query, List<String> expectedFeatureFields) {
-        assertTrue(query instanceof BooleanQuery);
-        List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
+        assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
+        var sparseQuery = (SparseVectorQueryWrapper) query;
+        assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
+        BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
+        List<BooleanClause> booleanClauses = booleanQuery.clauses();
         assertEquals(
             name + " had " + booleanClauses.size() + " clauses, expected " + expectedFeatureFields.size(),
             expectedFeatureFields.size(),
@@ -343,8 +346,10 @@ public class WeightedTokensQueryBuilderTests extends AbstractQueryTestCase<Weigh
 
     @Override
     protected void doAssertLuceneQuery(WeightedTokensQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
-        assertThat(query, instanceOf(BooleanQuery.class));
-        BooleanQuery booleanQuery = (BooleanQuery) query;
+        assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
+        var sparseQuery = (SparseVectorQueryWrapper) query;
+        assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
+        BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
         assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
         assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
 

+ 0 - 19
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -48,7 +48,6 @@ import org.elasticsearch.env.Environment;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.analysis.CharFilterFactory;
 import org.elasticsearch.index.analysis.TokenizerFactory;
-import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.indices.AssociatedIndexDescriptor;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
@@ -376,8 +375,6 @@ import org.elasticsearch.xpack.ml.process.MlControllerHolder;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.NativeStorageProvider;
-import org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder;
-import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder;
 import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction;
 import org.elasticsearch.xpack.ml.rest.RestMlInfoAction;
 import org.elasticsearch.xpack.ml.rest.RestMlMemoryAction;
@@ -1764,22 +1761,6 @@ public class MachineLearning extends Plugin
         );
     }
 
-    @Override
-    public List<QuerySpec<?>> getQueries() {
-        return List.of(
-            new QuerySpec<QueryBuilder>(
-                TextExpansionQueryBuilder.NAME,
-                TextExpansionQueryBuilder::new,
-                TextExpansionQueryBuilder::fromXContent
-            ),
-            new QuerySpec<QueryBuilder>(
-                SparseVectorQueryBuilder.NAME,
-                SparseVectorQueryBuilder::new,
-                SparseVectorQueryBuilder::fromXContent
-            )
-        );
-    }
-
     private <T> ContextParser<String, T> checkAggLicense(ContextParser<String, T> realParser, LicensedFeature.Momentary feature) {
         return (parser, name) -> {
             if (feature.check(getLicenseState()) == false) {