|
@@ -27,10 +27,12 @@ import org.elasticsearch.index.mapper.NestedObjectMapper;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
|
|
|
import org.elasticsearch.index.query.AbstractQueryBuilder;
|
|
|
+import org.elasticsearch.index.query.BoolQueryBuilder;
|
|
|
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryRewriteContext;
|
|
|
import org.elasticsearch.index.query.SearchExecutionContext;
|
|
|
+import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder;
|
|
|
import org.elasticsearch.index.search.NestedHelper;
|
|
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.xcontent.ObjectParser;
|
|
@@ -454,9 +456,6 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
vectorSimilarity
|
|
|
).boost(boost).queryName(queryName).addFilterQueries(filterQueries);
|
|
|
}
|
|
|
- if (ctx.convertToInnerHitsRewriteContext() != null) {
|
|
|
- return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName);
|
|
|
- }
|
|
|
boolean changed = false;
|
|
|
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
|
|
|
for (QueryBuilder query : filterQueries) {
|
|
@@ -481,6 +480,22 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
vectorSimilarity
|
|
|
).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries);
|
|
|
}
|
|
|
+ if (ctx.convertToInnerHitsRewriteContext() != null) {
|
|
|
+ QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
|
|
|
+ if (filterQueries.isEmpty()) {
|
|
|
+ return exactKnnQuery;
|
|
|
+ } else {
|
|
|
+ BoolQueryBuilder boolQuery = new BoolQueryBuilder();
|
|
|
+ boolQuery.must(exactKnnQuery);
|
|
|
+ for (QueryBuilder filter : this.filterQueries) {
|
|
|
+ // filter can be both over parents or nested docs, so add them as should clauses to a filter
|
|
|
+ BoolQueryBuilder adjustedFilter = new BoolQueryBuilder().should(filter)
|
|
|
+ .should(new ToChildBlockJoinQueryBuilder(filter));
|
|
|
+ boolQuery.filter(adjustedFilter);
|
|
|
+ }
|
|
|
+ return boolQuery;
|
|
|
+ }
|
|
|
+ }
|
|
|
return this;
|
|
|
}
|
|
|
|
|
@@ -500,29 +515,27 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
if (fieldType == null) {
|
|
|
return new MatchNoDocsQuery();
|
|
|
}
|
|
|
-
|
|
|
if (fieldType instanceof DenseVectorFieldType == false) {
|
|
|
throw new IllegalArgumentException(
|
|
|
"[" + NAME + "] queries are only supported on [" + DenseVectorFieldMapper.CONTENT_TYPE + "] fields"
|
|
|
);
|
|
|
}
|
|
|
+ DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
|
|
|
|
|
|
- BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
|
|
+ List<Query> filtersInitial = new ArrayList<>(filterQueries.size());
|
|
|
for (QueryBuilder query : this.filterQueries) {
|
|
|
- builder.add(query.toQuery(context), BooleanClause.Occur.FILTER);
|
|
|
+ filtersInitial.add(query.toQuery(context));
|
|
|
}
|
|
|
if (context.getAliasFilter() != null) {
|
|
|
- builder.add(context.getAliasFilter().toQuery(context), BooleanClause.Occur.FILTER);
|
|
|
+ filtersInitial.add(context.getAliasFilter().toQuery(context));
|
|
|
}
|
|
|
- BooleanQuery booleanQuery = builder.build();
|
|
|
- Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
|
|
|
|
|
|
- DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
|
|
|
String parentPath = context.nestedLookup().getNestedParent(fieldName);
|
|
|
- Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
|
|
|
-
|
|
|
BitSetProducer parentBitSet = null;
|
|
|
- if (parentPath != null) {
|
|
|
+ Query filterQuery;
|
|
|
+ if (parentPath == null) {
|
|
|
+ filterQuery = buildFilterQuery(filtersInitial);
|
|
|
+ } else {
|
|
|
final Query parentFilter;
|
|
|
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
|
|
|
if (originalObjectMapper != null) {
|
|
@@ -541,19 +554,23 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated());
|
|
|
}
|
|
|
parentBitSet = context.bitsetFilter(parentFilter);
|
|
|
- if (filterQuery != null) {
|
|
|
- // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents
|
|
|
- // we need to adjust it.
|
|
|
- if (NestedHelper.mightMatchNestedDocs(filterQuery, context)) {
|
|
|
- // Ensure that the query only returns parent documents matching `filterQuery`
|
|
|
- filterQuery = Queries.filtered(filterQuery, parentFilter);
|
|
|
+ List<Query> filterAdjusted = new ArrayList<>(filtersInitial.size());
|
|
|
+ for (Query f : filtersInitial) {
|
|
|
+ // If filter matches non-nested docs, we assume this is a filter over parents docs,
|
|
|
+ // so we will modify it accordingly: matching parents docs with join to its child docs
|
|
|
+ if (NestedHelper.mightMatchNonNestedDocs(f, parentPath, context)) {
|
|
|
+ // Ensure that the query only returns parent documents matching filter
|
|
|
+ f = Queries.filtered(f, parentFilter);
|
|
|
+ f = new ToChildBlockJoinQuery(f, parentBitSet);
|
|
|
}
|
|
|
- // Now join the filterQuery & parentFilter to provide the matching blocks of children
|
|
|
- filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
|
|
|
+ filterAdjusted.add(f);
|
|
|
}
|
|
|
+ filterQuery = buildFilterQuery(filterAdjusted);
|
|
|
}
|
|
|
+
|
|
|
DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
|
|
|
boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination();
|
|
|
+ Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
|
|
|
return vectorFieldType.createKnnQuery(
|
|
|
queryVector,
|
|
|
k,
|
|
@@ -567,6 +584,16 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private static Query buildFilterQuery(List<Query> filters) {
|
|
|
+ BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
|
|
+ for (Query f : filters) {
|
|
|
+ builder.add(f, BooleanClause.Occur.FILTER);
|
|
|
+ }
|
|
|
+ BooleanQuery booleanQuery = builder.build();
|
|
|
+ Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
|
|
|
+ return filterQuery;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected int doHashCode() {
|
|
|
return Objects.hash(
|