Browse Source

Implement support for weighted rrf (#130658)

* RRFRetrieverComponent added:

* Modified parser, toXcontent and included component in the RetrieverBuilder

* [CI] Auto commit changes from spotless

* Resolved merge conflicts

* Fixed compile issues in tests

* [CI] Auto commit changes from spotless

* trying to resolve parse errors

* wip

* Modified builder

* [CI] Auto commit changes from spotless

* Removed unnecessary code

* Fixed import

* Enhanced tests

* Fixed the failing tests

* Yaml tests were added

* Added cluster features to it

* Fixed spotless

* Update docs/changelog/130658.yaml

* Fixed the relaxed constraints

* Resolving issues

* Resolved PR comments

* removed simplified rrf

* changed the test file back to its original state

* Resolved comments to have a helper method and the test case to use it

* made parsing robust

* IT test reverted

* Replaced the declareString array parser

* Enforced weights as nonnull

* Fixed the weights null

* Empty weight shouldn't be serialised

* [CI] Auto commit changes from spotless

* removed the hard coding

* Cleanup and optimised the code flow

* Fixed the comments

* [CI] Auto commit changes from spotless

* optimised test

* Added additional test

* addressed the comments

* Update docs/changelog/130658.yaml

Co-authored-by: Liam Thompson <leemthompo@gmail.com>

* Explicit check for retriever object

* Resolved PR comments

* Fixed the error message

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Ioana Tagirta <ioana.tagirta@elastic.co>
Co-authored-by: Liam Thompson <leemthompo@gmail.com>
Mridula 2 months ago
parent
commit
fc0ea64fa2

+ 5 - 0
docs/changelog/130658.yaml

@@ -0,0 +1,5 @@
+pr: 130658
+summary: Add support for weighted RRF in retrievers
+area: Relevance
+type: enhancement
+issues: []

+ 2 - 1
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

@@ -36,7 +36,8 @@ public class RankRRFFeatures implements FeatureSpecification {
             LINEAR_RETRIEVER_L2_NORM,
             LINEAR_RETRIEVER_MINSCORE_FIX,
             LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
-            RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
+            RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
+            RRFRetrieverBuilder.WEIGHTED_SUPPORT
         );
     }
 }

+ 81 - 46
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

@@ -20,6 +20,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankBuilder;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
+import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
 import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
@@ -37,7 +38,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
-import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;
 
 /**
  * An rrf retriever is used to represent an rrf rank element, but
@@ -48,6 +49,7 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
  */
 public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
     public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
+    public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");
 
     public static final String NAME = "rrf";
 
@@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
     public static final ParseField QUERY_FIELD = new ParseField("query");
 
     public static final int DEFAULT_RANK_CONSTANT = 60;
+
+    private final float[] weights;
+
     @SuppressWarnings("unchecked")
     static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
         NAME,
         false,
         args -> {
-            List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
+            List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
             List<String> fields = (List<String>) args[1];
             String query = (String) args[2];
             int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
             int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
 
-            List<RetrieverSource> innerRetrievers = childRetrievers != null
-                ? childRetrievers.stream().map(RetrieverSource::from).toList()
-                : List.of();
-            return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
+            int n = retrieverComponents.size();
+            List<RetrieverSource> innerRetrievers = new ArrayList<>(n);
+            float[] weights = new float[n];
+            for (int i = 0; i < n; i++) {
+                RRFRetrieverComponent component = retrieverComponents.get(i);
+                innerRetrievers.add(RetrieverSource.from(component.retriever()));
+                weights[i] = component.weight();
+            }
+            return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
         }
     );
 
     static {
-        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
-            p.nextToken();
-            String name = p.currentName();
-            RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
-            c.trackRetrieverUsage(retrieverBuilder.getName());
-            p.nextToken();
-            return retrieverBuilder;
-        }, RETRIEVERS_FIELD);
-        PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
-        PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
-        PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
-        PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
+        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
+        PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
         RetrieverBuilder.declareBaseParserFields(PARSER);
     }
 
@@ -103,7 +106,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
     private final int rankConstant;
 
     public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
-        this(childRetrievers, null, null, rankWindowSize, rankConstant);
+        this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
+    }
+
+    /**
+     * Creates a weight array in which every retriever gets {@code DEFAULT_WEIGHT}.
+     *
+     * @param retrievers the retrievers to size the array for; {@code null} is treated like an empty list
+     * @return a new array holding one {@code DEFAULT_WEIGHT} entry per retriever
+     */
+    private static float[] createDefaultWeights(List<?> retrievers) {
+        if (retrievers == null) {
+            return new float[0];
+        }
+        float[] uniform = new float[retrievers.size()];
+        Arrays.fill(uniform, DEFAULT_WEIGHT);
+        return uniform;
     }
 
     public RRFRetrieverBuilder(
@@ -111,19 +121,31 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
         List<String> fields,
         String query,
         int rankWindowSize,
-        int rankConstant
+        int rankConstant,
+        float[] weights
     ) {
         // Use a mutable list for childRetrievers so that we can use addChild
         super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
         this.fields = fields == null ? null : List.copyOf(fields);
         this.query = query;
         this.rankConstant = rankConstant;
+        Objects.requireNonNull(weights, "weights must not be null");
+        if (weights.length != innerRetrievers.size()) {
+            throw new IllegalArgumentException(
+                "weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]"
+            );
+        }
+        this.weights = weights;
     }
 
     public int rankConstant() {
         return rankConstant;
     }
 
+    /**
+     * Returns the weights of the inner retrievers; entry {@code i} belongs to inner retriever {@code i}.
+     *
+     * @return a copy of the weights, so that callers cannot modify the values this builder scores with
+     */
+    public float[] weights() {
+        return weights.clone();
+    }
+
     @Override
     public String getName() {
         return NAME;
@@ -137,6 +159,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
         boolean allowPartialSearchResults
     ) {
         validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
+
         return MultiFieldsInnerRetrieverUtils.validateParams(
             innerRetrievers,
             fields,
@@ -151,7 +174,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
 
     @Override
     protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
-        RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
+        RRFRetrieverBuilder clone = new RRFRetrieverBuilder(
+            newRetrievers,
+            this.fields,
+            this.query,
+            this.rankWindowSize,
+            this.rankConstant,
+            this.weights
+        );
         clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
         clone.retrieverName = retrieverName;
         return clone;
@@ -183,7 +213,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
 
                     // calculate the current rrf score for this document
                     // later used to sort and covert to a rank
-                    value.score += 1.0f / (rankConstant + frank);
+                    value.score += this.weights[findex] * (1.0f / (rankConstant + frank));
 
                     if (explain && value.positions != null && value.scores != null) {
                         // record the position for each query
@@ -238,10 +268,14 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
                 query,
                 localIndicesMetadata.values(),
                 r -> {
-                    List<RetrieverSource> retrievers = r.stream()
-                        .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
-                        .toList();
-                    return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
+                    List<RetrieverSource> retrievers = new ArrayList<>(r.size());
+                    float[] weights = new float[r.size()];
+                    for (int i = 0; i < r.size(); i++) {
+                        var retriever = r.get(i);
+                        retrievers.add(retriever.retrieverSource());
+                        weights[i] = retriever.weight();
+                    }
+                    return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
                 },
                 w -> {
                     if (w != 1.0f) {
@@ -255,7 +289,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
             if (fieldsInnerRetrievers.isEmpty() == false) {
                 // TODO: This is a incomplete solution as it does not address other incomplete copy issues
                 // (such as dropping the retriever name and min score)
-                rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
+                float[] weights = createDefaultWeights(fieldsInnerRetrievers);
+                rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights);
                 rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
             } else {
                 // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
@@ -266,29 +301,13 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
         return rewritten;
     }
 
-    // ---- FOR TESTING XCONTENT PARSING ----
-
-    @Override
-    public boolean doEquals(Object o) {
-        RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
-        return super.doEquals(o)
-            && Objects.equals(fields, that.fields)
-            && Objects.equals(query, that.query)
-            && rankConstant == that.rankConstant;
-    }
-
-    @Override
-    public int doHashCode() {
-        return Objects.hash(super.doHashCode(), fields, query, rankConstant);
-    }
-
     @Override
     public void doToXContent(XContentBuilder builder, Params params) throws IOException {
         if (innerRetrievers.isEmpty() == false) {
             builder.startArray(RETRIEVERS_FIELD.getPreferredName());
-
-            for (var entry : innerRetrievers) {
-                entry.retriever().toXContent(builder, params);
+            for (int i = 0; i < innerRetrievers.size(); i++) {
+                RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]);
+                component.toXContent(builder, params);
             }
             builder.endArray();
         }
@@ -307,4 +326,20 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
         builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
         builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
     }
+
+    // ---- FOR TESTING XCONTENT PARSING ----
+    /**
+     * Compares the rrf-specific state (fields, query, rank constant and weights) in addition to the base class check.
+     */
+    @Override
+    public boolean doEquals(Object o) {
+        RRFRetrieverBuilder other = (RRFRetrieverBuilder) o;
+        if (super.doEquals(o) == false) {
+            return false;
+        }
+        if (Objects.equals(fields, other.fields) == false || Objects.equals(query, other.query) == false) {
+            return false;
+        }
+        return rankConstant == other.rankConstant && Arrays.equals(weights, other.weights);
+    }
+
+    @Override
+    public int doHashCode() {
+        return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
+    }
 }

+ 124 - 0
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java

@@ -0,0 +1,124 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.rank.rrf;
+
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.search.retriever.RetrieverBuilder;
+import org.elasticsearch.search.retriever.RetrieverParserContext;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class RRFRetrieverComponent implements ToXContentObject {
+
+    public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
+    public static final ParseField WEIGHT_FIELD = new ParseField("weight");
+    static final float DEFAULT_WEIGHT = 1f;
+
+    final RetrieverBuilder retriever;
+    final float weight;
+
+    public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) {
+        this.retriever = Objects.requireNonNull(retrieverBuilder, "retrieverBuilder must not be null");
+        this.weight = weight == null ? DEFAULT_WEIGHT : weight;
+        if (this.weight < 0) {
+            throw new IllegalArgumentException("[weight] must be non-negative, found [" + this.weight + "]");
+        }
+    }
+
+    public RetrieverBuilder retriever() {
+        return retriever;
+    }
+
+    public float weight() {
+        return weight;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException {
+        builder.startObject();
+        builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
+        builder.field(WEIGHT_FIELD.getPreferredName(), weight);
+        builder.endObject();
+        return builder;
+    }
+
+    public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
+        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
+            throw new ParsingException(parser.getTokenLocation(), "expected object but found [{}]", parser.currentToken());
+        }
+
+        // Peek at the first field to determine the format
+        XContentParser.Token token = parser.nextToken();
+        if (token == XContentParser.Token.END_OBJECT) {
+            throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
+        }
+        if (token != XContentParser.Token.FIELD_NAME) {
+            throw new ParsingException(parser.getTokenLocation(), "expected field name but found [{}]", token);
+        }
+
+        String firstFieldName = parser.currentName();
+
+        // Check if this is a structured component (starts with "retriever" or "weight")
+        if (RETRIEVER_FIELD.match(firstFieldName, parser.getDeprecationHandler())
+            || WEIGHT_FIELD.match(firstFieldName, parser.getDeprecationHandler())) {
+            // This is a structured component - parse manually
+            RetrieverBuilder retriever = null;
+            Float weight = null;
+
+            do {
+                String fieldName = parser.currentName();
+                if (RETRIEVER_FIELD.match(fieldName, parser.getDeprecationHandler())) {
+                    if (retriever != null) {
+                        throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified");
+                    }
+                    parser.nextToken();
+                    if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
+                        throw new ParsingException(parser.getTokenLocation(), "retriever must be an object");
+                    }
+                    parser.nextToken();
+                    String retrieverType = parser.currentName();
+                    retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context);
+                    context.trackRetrieverUsage(retriever.getName());
+                    parser.nextToken();
+                } else if (WEIGHT_FIELD.match(fieldName, parser.getDeprecationHandler())) {
+                    if (weight != null) {
+                        throw new ParsingException(parser.getTokenLocation(), "[weight] field can only be specified once");
+                    }
+                    parser.nextToken();
+                    weight = parser.floatValue();
+                } else {
+                    throw new ParsingException(
+                        parser.getTokenLocation(),
+                        "unknown field [{}], expected [{}] or [{}]",
+                        fieldName,
+                        RETRIEVER_FIELD.getPreferredName(),
+                        WEIGHT_FIELD.getPreferredName()
+                    );
+                }
+            } while (parser.nextToken() == XContentParser.Token.FIELD_NAME);
+
+            if (retriever == null) {
+                throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
+            }
+
+            return new RRFRetrieverComponent(retriever, weight);
+        } else {
+            RetrieverBuilder retriever = parser.namedObject(RetrieverBuilder.class, firstFieldName, context);
+            context.trackRetrieverUsage(retriever.getName());
+            if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName());
+            }
+            return new RRFRetrieverComponent(retriever, DEFAULT_WEIGHT);
+        }
+    }
+}

+ 245 - 17
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

@@ -27,6 +27,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 
@@ -56,12 +57,15 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
 
         int retrieverCount = randomIntBetween(2, 50);
         List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
+        float[] weights = new float[retrieverCount];
+        int i = 0;
         while (retrieverCount > 0) {
             innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
+            weights[i++] = randomFloat();
             --retrieverCount;
         }
 
-        return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
+        return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
     }
 
     @Override
@@ -89,7 +93,7 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
             new NamedXContentRegistry.Entry(
                 RetrieverBuilder.class,
                 TestRetrieverBuilder.TEST_SPEC.getName(),
-                (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c),
+                (p, c) -> TestRetrieverBuilder.fromXContent(p, (RetrieverParserContext) c),
                 TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion()
             )
         );
@@ -103,6 +107,28 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
         return new NamedXContentRegistry(entries);
     }
 
+    /**
+     * Parses {@code restContent} as a search source and checks that it holds an rrf retriever with
+     * {@code min_score} 20 and {@code _name} "foo_rrf". The source is then serialized and parsed again,
+     * and the re-parsed retriever must equal the first one.
+     *
+     * @param restContent the JSON search request body to parse
+     */
+    private void checkRRFRetrieverParsing(String restContent) throws IOException {
+        SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
+        try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
+            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class));
+            RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever();
+            assertThat(parsed.minScore(), equalTo(20f));
+            assertThat(parsed.retrieverName(), equalTo("foo_rrf"));
+            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
+                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
+                    parseSerialized,
+                    true,
+                    searchUsageHolder,
+                    nf -> true
+                );
+                assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class));
+                // Take the builder from the re-parsed source; reading it from `source` would compare `parsed` with itself.
+                RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) deserializedSource.retriever();
+                assertThat(parsed, equalTo(deserialized));
+            }
+        }
+    }
+
     public void testRRFRetrieverParsing() throws IOException {
         String restContent = """
             {
@@ -130,24 +156,226 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
               }
             }
             """;
+        checkRRFRetrieverParsing(restContent);
+    }
+
+    /**
+     * Runs the parsing check on an rrf retriever whose entries all use the structured
+     * {@code retriever}/{@code weight} form.
+     */
+    public void testRRFRetrieverParsingWithWeights() throws IOException {
+        String restContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "retriever": {
+                        "test": {
+                          "value": "first"
+                        }
+                      },
+                      "weight": 2.0
+                    },
+                    {
+                      "retriever": {
+                        "test": {
+                          "value": "second"
+                        }
+                      },
+                      "weight": 0.5
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+        checkRRFRetrieverParsing(restContent);
+    }
+
+    /**
+     * Runs the parsing check on an rrf retriever that mixes a bare retriever entry (no weight given)
+     * with a structured entry that carries an explicit weight.
+     */
+    public void testRRFRetrieverParsingWithMixedWeights() throws IOException {
+        String restContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "test": {
+                        "value": "no_weight"
+                      }
+                    },
+                    {
+                      "retriever": {
+                        "test": {
+                          "value": "with_weight"
+                        }
+                      },
+                      "weight": 1.5
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+        checkRRFRetrieverParsing(restContent);
+    }
+
+    /**
+     * Runs the parsing check on an rrf retriever whose entries are all bare retrievers, i.e. none
+     * of them specifies a weight.
+     */
+    public void testRRFRetrieverParsingWithDefaultWeights() throws IOException {
+        String restContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "test": {
+                        "value": "first"
+                      }
+                    },
+                    {
+                      "test": {
+                        "value": "second"
+                      }
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+        checkRRFRetrieverParsing(restContent);
+    }
+
+    /**
+     * Checks that malformed entries of the {@code retrievers} array make parsing fail, and that the
+     * failure message contains the expected fragment for each case.
+     */
+    public void testRRFRetrieverComponentErrorCases() throws IOException {
+        // Test case 1: Multiple retrievers in same component
+        String multipleRetrieversContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "retriever": { "test": { "value": "first" } },
+                      "standard": { "query": { "match_all": {} } }
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+
+        expectParsingException(multipleRetrieversContent, "unknown field [standard], expected [retriever] or [weight]");
+
+        // Test case 2: Weight without retriever
+        String weightOnlyContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "weight": 2.0
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+
+        expectParsingException(weightOnlyContent, "retriever component must contain a retriever");
+
+        // Test case 3: Empty retriever component
+        String emptyComponentContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {}
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+
+        expectParsingException(emptyComponentContent, "retriever component must contain a retriever");
+
+        // Test case 4: Negative weight
+        String negativeWeightContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "retriever": { "test": { "value": "test" } },
+                      "weight": -1.0
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+
+        expectParsingException(negativeWeightContent, "[weight] must be non-negative");
+
+        // Test case 5: Retriever as non-object
+        String retrieverAsStringContent = """
+            {
+              "retriever": {
+                "rrf": {
+                  "retrievers": [
+                    {
+                      "retriever": "not_an_object"
+                    }
+                  ],
+                  "rank_window_size": 100,
+                  "rank_constant": 10,
+                  "min_score": 20.0,
+                  "_name": "foo_rrf"
+                }
+              }
+            }
+            """;
+
+        expectParsingException(retrieverAsStringContent, "retriever must be an object");
+    }
+
+    private void expectParsingException(String restContent, String expectedMessageFragment) throws IOException {
         SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
         try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
-            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
-            assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class));
-            RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever();
-            assertThat(parsed.minScore(), equalTo(20f));
-            assertThat(parsed.retrieverName(), equalTo("foo_rrf"));
-            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
-                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
-                    parseSerialized,
-                    true,
-                    searchUsageHolder,
-                    nf -> true
-                );
-                assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class));
-                RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever();
-                assertThat(parsed, equalTo(deserialized));
+            Exception exception = expectThrows(Exception.class, () -> {
+                new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            });
+
+            String message = exception.getMessage();
+            if (exception.getCause() != null) {
+                message = exception.getCause().getMessage();
             }
+
+            assertThat(
+                "Expected error message to contain: " + expectedMessageFragment + ", but got: " + message,
+                message,
+                containsString(expectedMessageFragment)
+            );
         }
     }
 }

+ 67 - 5
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

@@ -39,8 +39,10 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiConsumer;
 
 import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
+import static org.hamcrest.Matchers.instanceOf;
 
 /** Tests for the rrf retriever. */
 public class RRFRetrieverBuilderTests extends ESTestCase {
@@ -84,6 +86,61 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
         }
     }
 
+    public void testRRFRetrieverParsingSyntax() throws IOException {
+        BiConsumer<String, float[]> testCase = (json, expectedWeights) -> {
+            try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
+                SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true);
+                assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class));
+                RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever();
+                assertArrayEquals(expectedWeights, rrf.weights(), 0.001f);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        };
+
+        String legacyJson = """
+            {
+              "retriever": {
+                "rrf_nl": {
+                  "retrievers": [
+                    { "standard": { "query": { "match_all": {} } } },
+                    { "standard": { "query": { "match_all": {} } } }
+                  ]
+                }
+              }
+            }
+            """;
+        testCase.accept(legacyJson, new float[] { 1.0f, 1.0f });
+
+        String weightedJson = """
+            {
+              "retriever": {
+                "rrf_nl": {
+                  "retrievers": [
+                    { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 2.5 },
+                    { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 0.5 }
+                  ]
+                }
+              }
+            }
+            """;
+        testCase.accept(weightedJson, new float[] { 2.5f, 0.5f });
+
+        String mixedJson = """
+            {
+              "retriever": {
+                "rrf_nl": {
+                  "retrievers": [
+                    { "standard": { "query": { "match_all": {} } } },
+                    { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 0.6 }
+                  ]
+                }
+              }
+            }
+            """;
+        testCase.accept(mixedJson, new float[] { 1.0f, 0.6f });
+    }
+
     public void testMultiFieldsParamsRewrite() {
         final String indexName = "test-index";
         final List<String> testInferenceFields = List.of("semantic_field_1", "semantic_field_2");
@@ -103,7 +160,8 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
             "foo",
             DEFAULT_RANK_WINDOW_SIZE,
-            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
+            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
+            new float[0]
         );
         assertMultiFieldsParamsRewrite(
             rrfRetrieverBuilder,
@@ -119,7 +177,8 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
             "foo2",
             DEFAULT_RANK_WINDOW_SIZE * 2,
-            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2
+            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2,
+            new float[0]
         );
         assertMultiFieldsParamsRewrite(
             rrfRetrieverBuilder,
@@ -135,7 +194,8 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             List.of("field_*", "*_field_1"),
             "bar",
             DEFAULT_RANK_WINDOW_SIZE,
-            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
+            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
+            new float[0]
         );
         assertMultiFieldsParamsRewrite(
             rrfRetrieverBuilder,
@@ -151,7 +211,8 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             List.of("*"),
             "baz",
             DEFAULT_RANK_WINDOW_SIZE,
-            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
+            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
+            new float[0]
         );
         assertMultiFieldsParamsRewrite(
             rrfRetrieverBuilder,
@@ -182,7 +243,8 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             null,
             "foo",
             DEFAULT_RANK_WINDOW_SIZE,
-            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
+            RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
+            new float[0]
         );
 
         IllegalArgumentException iae = expectThrows(

+ 139 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/320_rrf_weighted_retriever.yml

@@ -0,0 +1,139 @@
+setup:
+  - requires:
+      cluster_features: [ "rrf_retriever.weighted_support" ]
+      reason: "RRF retriever Weighted support"
+      test_runner_features: [ "contains", "close_to" ]
+  - do:
+      indices.create:
+        index: restaurants
+        body:
+          mappings:
+            properties:
+              name: { type: keyword }
+              description: { type: text }
+              city: { type: keyword }
+              region: { type: keyword }
+              vector: { type: dense_vector, dims: 3 }
+  - do:
+      index:
+        index: restaurants
+        id: "1"
+        body: { name: "Pizza Palace", description: "Best pizza in town", city: "Vienna", region: "Austria", vector: [10,22,77] }
+  - do:
+      index:
+        index: restaurants
+        id: "2"
+        body: { name: "Burger House", description: "Juicy burgers", city: "Graz", region: "Austria", vector: [15,25,70] }
+  - do:
+      index:
+        index: restaurants
+        id: "3"
+        body: { name: "Sushi World", description: "Fresh sushi", city: "Linz", region: "Austria", vector: [11,24,75] }
+  - do:
+      indices.refresh: { index: restaurants }
+
+---
+"Weighted RRF retriever returns correct results":
+  - do:
+      search:
+        index: restaurants
+        body:
+          retriever:
+            rrf:
+              retrievers:
+                - retriever:
+                    standard:
+                      query:
+                        multi_match:
+                          query: "Austria"
+                          fields: ["city", "region"]
+                  weight: 0.3
+                - retriever:
+                    standard:
+                      query:
+                        match:
+                          description: "pizza"
+                  weight: 0.7
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "1" }
+
+---
+"Weighted RRF retriever allows optional weight field":
+  - do:
+      search:
+        index: restaurants
+        body:
+          retriever:
+            rrf:
+              retrievers:
+                - standard:
+                    query:
+                      multi_match:
+                        query: "Austria"
+                        fields: ["city", "region"]
+                - retriever:
+                    standard:
+                      query:
+                        match:
+                          description: "pizza"
+                  weight: 0.7
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "1" }
+
+---
+"Weighted RRF retriever changes result order":
+  - do:
+      search:
+        index: restaurants
+        body:
+          retriever:
+            rrf:
+              retrievers:
+                - retriever:
+                    standard:
+                      query:
+                        match:
+                          description: "pizza"
+                  weight: 0.1
+                - retriever:
+                    standard:
+                      query:
+                        match:
+                          description: "burgers"
+                  weight: 0.9
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "2" }
+  - match: { hits.hits.1._id: "1" }
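+  # Document 2: ranked first by the "burgers" retriever (weight 0.9); not matched by the "pizza" retriever
+  # RRF score = weight * 1/(rank_constant + rank) = 0.9 * 1/(60+1) = 0.01475 (rounded)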
+  - close_to: {hits.hits.0._score: {value:  0.01475, error: 0.0001}}
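+  # Document 1: ranked first by the "pizza" retriever (weight 0.1); not matched by the "burgers" retriever
+  # RRF score = weight * 1/(rank_constant + rank) = 0.1 * 1/(60+1) = 0.00164 (rounded)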
+  - close_to: {hits.hits.1._score: {value:  0.00164, error: 0.0001}}
+
+---
+"Weighted RRF retriever errors on negative weight":
+  - do:
+      catch: bad_request
+      search:
+        index: restaurants
+        body:
+          retriever:
+            rrf:
+              retrievers:
+                - retriever:
+                    standard:
+                      query:
+                        multi_match:
+                          query: "Austria"
+                          fields: ["city", "region"]
+                  weight: -0.5
+                - retriever:
+                    standard:
+                      query:
+                        match:
+                          description: "pizza"
+                  weight: 0.7
+  - match: { error.type: "x_content_parse_exception" }
+  - contains: { error.caused_by.reason: "[weight] must be non-negative, found [-0.5]" }
+