|
@@ -8,21 +8,28 @@
|
|
|
package org.elasticsearch.xpack.rank.rrf;
|
|
|
|
|
|
import org.apache.lucene.search.ScoreDoc;
|
|
|
+import org.elasticsearch.action.ActionRequestValidationException;
|
|
|
+import org.elasticsearch.action.ResolvedIndices;
|
|
|
import org.elasticsearch.common.ParsingException;
|
|
|
import org.elasticsearch.common.util.Maps;
|
|
|
import org.elasticsearch.features.NodeFeature;
|
|
|
+import org.elasticsearch.index.query.MatchNoneQueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
|
+import org.elasticsearch.index.query.QueryRewriteContext;
|
|
|
import org.elasticsearch.license.LicenseUtils;
|
|
|
+import org.elasticsearch.search.builder.SearchSourceBuilder;
|
|
|
import org.elasticsearch.search.rank.RankBuilder;
|
|
|
import org.elasticsearch.search.rank.RankDoc;
|
|
|
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
|
|
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
|
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
|
|
+import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
|
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.xcontent.ParseField;
|
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.xcontent.XContentParser;
|
|
|
import org.elasticsearch.xpack.core.XPackPlugin;
|
|
|
+import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
@@ -31,7 +38,6 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
|
|
|
-import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
|
|
|
/**
|
|
@@ -42,6 +48,7 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
|
|
|
* formula.
|
|
|
*/
|
|
|
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
|
|
|
+ public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
|
|
|
|
|
|
public static final String NAME = "rrf";
|
|
|
public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported", true);
|
|
@@ -49,6 +56,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
|
|
|
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
|
|
|
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
|
|
|
+ public static final ParseField FIELDS_FIELD = new ParseField("fields");
|
|
|
+ public static final ParseField QUERY_FIELD = new ParseField("query");
|
|
|
|
|
|
public static final int DEFAULT_RANK_CONSTANT = 60;
|
|
|
@SuppressWarnings("unchecked")
|
|
@@ -57,15 +66,20 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
false,
|
|
|
args -> {
|
|
|
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
|
|
|
- List<RetrieverSource> innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList();
|
|
|
- int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
|
|
|
- int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2];
|
|
|
- return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant);
|
|
|
+ List<String> fields = (List<String>) args[1];
|
|
|
+ String query = (String) args[2];
|
|
|
+ int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
|
|
|
+ int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
|
|
|
+
|
|
|
+ List<RetrieverSource> innerRetrievers = childRetrievers != null
|
|
|
+ ? childRetrievers.stream().map(RetrieverSource::from).toList()
|
|
|
+ : List.of();
|
|
|
+ return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
|
|
|
}
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareObjectArray(constructorArg(), (p, c) -> {
|
|
|
+ PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
|
|
|
p.nextToken();
|
|
|
String name = p.currentName();
|
|
|
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
|
|
@@ -73,6 +87,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
p.nextToken();
|
|
|
return retrieverBuilder;
|
|
|
}, RETRIEVERS_FIELD);
|
|
|
+ PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
|
|
|
+ PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
|
|
|
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
|
|
|
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
|
|
|
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
|
|
@@ -91,25 +107,60 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
return PARSER.apply(parser, context);
|
|
|
}
|
|
|
|
|
|
+ private final List<String> fields;
|
|
|
+ private final String query;
|
|
|
private final int rankConstant;
|
|
|
|
|
|
- public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
|
|
|
- this(new ArrayList<>(), rankWindowSize, rankConstant);
|
|
|
+ public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
|
|
|
+ this(childRetrievers, null, null, rankWindowSize, rankConstant);
|
|
|
}
|
|
|
|
|
|
- RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
|
|
|
- super(childRetrievers, rankWindowSize);
|
|
|
+ public RRFRetrieverBuilder(
|
|
|
+ List<RetrieverSource> childRetrievers,
|
|
|
+ List<String> fields,
|
|
|
+ String query,
|
|
|
+ int rankWindowSize,
|
|
|
+ int rankConstant
|
|
|
+ ) {
|
|
|
+ // Use a mutable list for childRetrievers so that we can use addChild
|
|
|
+ super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
|
|
|
+ this.fields = fields == null ? List.of() : List.copyOf(fields);
|
|
|
+ this.query = query;
|
|
|
this.rankConstant = rankConstant;
|
|
|
}
|
|
|
|
|
|
+ public int rankConstant() {
|
|
|
+ return rankConstant;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public String getName() {
|
|
|
return NAME;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public ActionRequestValidationException validate(
|
|
|
+ SearchSourceBuilder source,
|
|
|
+ ActionRequestValidationException validationException,
|
|
|
+ boolean isScroll,
|
|
|
+ boolean allowPartialSearchResults
|
|
|
+ ) {
|
|
|
+ validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
|
|
|
+ return MultiFieldsInnerRetrieverUtils.validateParams(
|
|
|
+ innerRetrievers,
|
|
|
+ fields,
|
|
|
+ query,
|
|
|
+ getName(),
|
|
|
+ RETRIEVERS_FIELD.getPreferredName(),
|
|
|
+ FIELDS_FIELD.getPreferredName(),
|
|
|
+ QUERY_FIELD.getPreferredName(),
|
|
|
+ validationException
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
|
|
- RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
|
|
|
+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
|
|
|
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
|
|
clone.retrieverName = retrieverName;
|
|
|
return clone;
|
|
@@ -172,17 +223,72 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
return topResults;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
|
|
|
+ RetrieverBuilder rewritten = this;
|
|
|
+
|
|
|
+ ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
|
|
|
+ if (resolvedIndices != null && query != null) {
|
|
|
+ // TODO: Refactor duplicate code
|
|
|
+ // Using the multi-fields query format
|
|
|
+ var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
|
|
|
+ if (localIndicesMetadata.size() > 1) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
|
|
|
+ );
|
|
|
+ } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
|
|
|
+ fields,
|
|
|
+ query,
|
|
|
+ localIndicesMetadata.values(),
|
|
|
+ r -> {
|
|
|
+ List<RetrieverSource> retrievers = r.stream()
|
|
|
+ .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
|
|
|
+ .toList();
|
|
|
+ return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
|
|
|
+ },
|
|
|
+ w -> {
|
|
|
+ if (w != 1.0f) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ).stream().map(RetrieverSource::from).toList();
|
|
|
+
|
|
|
+ if (fieldsInnerRetrievers.isEmpty() == false) {
|
|
|
+ // TODO: This is a incomplete solution as it does not address other incomplete copy issues
|
|
|
+ // (such as dropping the retriever name and min score)
|
|
|
+ rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
|
|
|
+ rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
|
|
|
+ } else {
|
|
|
+ // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
|
|
|
+ rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return rewritten;
|
|
|
+ }
|
|
|
+
|
|
|
// ---- FOR TESTING XCONTENT PARSING ----
|
|
|
|
|
|
@Override
|
|
|
public boolean doEquals(Object o) {
|
|
|
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
|
|
|
- return super.doEquals(o) && rankConstant == that.rankConstant;
|
|
|
+ return super.doEquals(o)
|
|
|
+ && Objects.equals(fields, that.fields)
|
|
|
+ && Objects.equals(query, that.query)
|
|
|
+ && rankConstant == that.rankConstant;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int doHashCode() {
|
|
|
- return Objects.hash(super.doHashCode(), rankConstant);
|
|
|
+ return Objects.hash(super.doHashCode(), fields, query, rankConstant);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -196,6 +302,17 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
builder.endArray();
|
|
|
}
|
|
|
|
|
|
+ if (fields.isEmpty() == false) {
|
|
|
+ builder.startArray(FIELDS_FIELD.getPreferredName());
|
|
|
+ for (String field : fields) {
|
|
|
+ builder.value(field);
|
|
|
+ }
|
|
|
+ builder.endArray();
|
|
|
+ }
|
|
|
+ if (query != null) {
|
|
|
+ builder.field(QUERY_FIELD.getPreferredName(), query);
|
|
|
+ }
|
|
|
+
|
|
|
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
|
|
|
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
|
|
|
}
|