|
@@ -20,6 +20,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
|
|
|
import org.elasticsearch.search.rank.RankBuilder;
|
|
|
import org.elasticsearch.search.rank.RankDoc;
|
|
|
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
|
|
+import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
|
|
|
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
|
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
|
|
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
|
@@ -37,7 +38,7 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
|
|
|
-import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
+import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;
|
|
|
|
|
|
/**
|
|
|
* An rrf retriever is used to represent an rrf rank element, but
|
|
@@ -48,6 +49,7 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
|
|
|
*/
|
|
|
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
|
|
|
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
|
|
|
+ public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");
|
|
|
|
|
|
public static final String NAME = "rrf";
|
|
|
|
|
@@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
public static final ParseField QUERY_FIELD = new ParseField("query");
|
|
|
|
|
|
public static final int DEFAULT_RANK_CONSTANT = 60;
|
|
|
+
|
|
|
+ private final float[] weights;
|
|
|
+
|
|
|
@SuppressWarnings("unchecked")
|
|
|
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
|
|
|
NAME,
|
|
|
false,
|
|
|
args -> {
|
|
|
- List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
|
|
|
+ List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
|
|
|
List<String> fields = (List<String>) args[1];
|
|
|
String query = (String) args[2];
|
|
|
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
|
|
|
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
|
|
|
|
|
|
- List<RetrieverSource> innerRetrievers = childRetrievers != null
|
|
|
- ? childRetrievers.stream().map(RetrieverSource::from).toList()
|
|
|
- : List.of();
|
|
|
- return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
|
|
|
+ int n = retrieverComponents.size();
|
|
|
+ List<RetrieverSource> innerRetrievers = new ArrayList<>(n);
|
|
|
+ float[] weights = new float[n];
|
|
|
+ for (int i = 0; i < n; i++) {
|
|
|
+ RRFRetrieverComponent component = retrieverComponents.get(i);
|
|
|
+ innerRetrievers.add(RetrieverSource.from(component.retriever()));
|
|
|
+ weights[i] = component.weight();
|
|
|
+ }
|
|
|
+ return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
|
|
|
}
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
|
|
|
- p.nextToken();
|
|
|
- String name = p.currentName();
|
|
|
- RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
|
|
|
- c.trackRetrieverUsage(retrieverBuilder.getName());
|
|
|
- p.nextToken();
|
|
|
- return retrieverBuilder;
|
|
|
- }, RETRIEVERS_FIELD);
|
|
|
- PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
|
|
|
- PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
|
|
|
- PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
|
|
|
- PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
|
|
|
+ PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
|
|
|
+ PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
|
|
|
+ PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
|
|
|
RetrieverBuilder.declareBaseParserFields(PARSER);
|
|
|
}
|
|
|
|
|
@@ -103,7 +106,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
private final int rankConstant;
|
|
|
|
|
|
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
|
|
|
- this(childRetrievers, null, null, rankWindowSize, rankConstant);
|
|
|
+ this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
|
|
|
+ }
|
|
|
+
|
|
|
+ private static float[] createDefaultWeights(List<?> retrievers) {
|
|
|
+ int size = retrievers == null ? 0 : retrievers.size();
|
|
|
+ float[] defaultWeights = new float[size];
|
|
|
+ Arrays.fill(defaultWeights, DEFAULT_WEIGHT);
|
|
|
+ return defaultWeights;
|
|
|
}
|
|
|
|
|
|
public RRFRetrieverBuilder(
|
|
@@ -111,19 +121,31 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
List<String> fields,
|
|
|
String query,
|
|
|
int rankWindowSize,
|
|
|
- int rankConstant
|
|
|
+ int rankConstant,
|
|
|
+ float[] weights
|
|
|
) {
|
|
|
// Use a mutable list for childRetrievers so that we can use addChild
|
|
|
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
|
|
|
this.fields = fields == null ? null : List.copyOf(fields);
|
|
|
this.query = query;
|
|
|
this.rankConstant = rankConstant;
|
|
|
+ Objects.requireNonNull(weights, "weights must not be null");
|
|
|
+ if (weights.length != innerRetrievers.size()) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]"
|
|
|
+ );
|
|
|
+ }
|
|
|
+ this.weights = weights;
|
|
|
}
|
|
|
|
|
|
public int rankConstant() {
|
|
|
return rankConstant;
|
|
|
}
|
|
|
|
|
|
+ public float[] weights() {
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public String getName() {
|
|
|
return NAME;
|
|
@@ -137,6 +159,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
boolean allowPartialSearchResults
|
|
|
) {
|
|
|
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
|
|
|
+
|
|
|
return MultiFieldsInnerRetrieverUtils.validateParams(
|
|
|
innerRetrievers,
|
|
|
fields,
|
|
@@ -151,7 +174,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
|
|
|
@Override
|
|
|
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
|
|
|
- RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
|
|
|
+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder(
|
|
|
+ newRetrievers,
|
|
|
+ this.fields,
|
|
|
+ this.query,
|
|
|
+ this.rankWindowSize,
|
|
|
+ this.rankConstant,
|
|
|
+ this.weights
|
|
|
+ );
|
|
|
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
|
|
|
clone.retrieverName = retrieverName;
|
|
|
return clone;
|
|
@@ -183,7 +213,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
|
|
|
// calculate the current rrf score for this document
|
|
|
// later used to sort and covert to a rank
|
|
|
- value.score += 1.0f / (rankConstant + frank);
|
|
|
+ value.score += this.weights[findex] * (1.0f / (rankConstant + frank));
|
|
|
|
|
|
if (explain && value.positions != null && value.scores != null) {
|
|
|
// record the position for each query
|
|
@@ -238,10 +268,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
query,
|
|
|
localIndicesMetadata.values(),
|
|
|
r -> {
|
|
|
- List<RetrieverSource> retrievers = r.stream()
|
|
|
- .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
|
|
|
- .toList();
|
|
|
- return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
|
|
|
+ List<RetrieverSource> retrievers = new ArrayList<>(r.size());
|
|
|
+ float[] weights = new float[r.size()];
|
|
|
+ for (int i = 0; i < r.size(); i++) {
|
|
|
+ var retriever = r.get(i);
|
|
|
+ retrievers.add(retriever.retrieverSource());
|
|
|
+ weights[i] = retriever.weight();
|
|
|
+ }
|
|
|
+ return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
|
|
|
},
|
|
|
w -> {
|
|
|
if (w != 1.0f) {
|
|
@@ -255,7 +289,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
if (fieldsInnerRetrievers.isEmpty() == false) {
|
|
|
// TODO: This is a incomplete solution as it does not address other incomplete copy issues
|
|
|
// (such as dropping the retriever name and min score)
|
|
|
- rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
|
|
|
+ float[] weights = createDefaultWeights(fieldsInnerRetrievers);
|
|
|
+ rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights);
|
|
|
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
|
|
|
} else {
|
|
|
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
|
|
@@ -266,29 +301,13 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
return rewritten;
|
|
|
}
|
|
|
|
|
|
- // ---- FOR TESTING XCONTENT PARSING ----
|
|
|
-
|
|
|
- @Override
|
|
|
- public boolean doEquals(Object o) {
|
|
|
- RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
|
|
|
- return super.doEquals(o)
|
|
|
- && Objects.equals(fields, that.fields)
|
|
|
- && Objects.equals(query, that.query)
|
|
|
- && rankConstant == that.rankConstant;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public int doHashCode() {
|
|
|
- return Objects.hash(super.doHashCode(), fields, query, rankConstant);
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
if (innerRetrievers.isEmpty() == false) {
|
|
|
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
|
|
|
-
|
|
|
- for (var entry : innerRetrievers) {
|
|
|
- entry.retriever().toXContent(builder, params);
|
|
|
+ for (int i = 0; i < innerRetrievers.size(); i++) {
|
|
|
+ RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]);
|
|
|
+ component.toXContent(builder, params);
|
|
|
}
|
|
|
builder.endArray();
|
|
|
}
|
|
@@ -307,4 +326,20 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|
|
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
|
|
|
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
|
|
|
}
|
|
|
+
|
|
|
+ // ---- FOR TESTING XCONTENT PARSING ----
|
|
|
+ @Override
|
|
|
+ public boolean doEquals(Object o) {
|
|
|
+ RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
|
|
|
+ return super.doEquals(o)
|
|
|
+ && Objects.equals(fields, that.fields)
|
|
|
+ && Objects.equals(query, that.query)
|
|
|
+ && rankConstant == that.rankConstant
|
|
|
+ && Arrays.equals(weights, that.weights);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int doHashCode() {
|
|
|
+ return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
|
|
|
+ }
|
|
|
}
|