فهرست منبع

Updating toXContent implementation for retrievers (#114017) (#114272)

Panagiotis Bailis 1 سال پیش
والد
کامیت
309d234bf0

+ 8 - 0
server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java

@@ -251,11 +251,19 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
     @Override
     public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
         builder.startObject();
+        builder.startObject(getName());
         if (preFilterQueryBuilders.isEmpty() == false) {
             builder.field(PRE_FILTER_FIELD.getPreferredName(), preFilterQueryBuilders);
         }
+        if (minScore != null) {
+            builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
+        }
+        if (retrieverName != null) {
+            builder.field(NAME_FIELD.getPreferredName(), retrieverName);
+        }
         doToXContent(builder, params);
         builder.endObject();
+        builder.endObject();
 
         return builder;
     }

+ 71 - 0
server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java

@@ -41,6 +41,8 @@ import org.elasticsearch.search.collapse.CollapseBuilder;
 import org.elasticsearch.search.collapse.CollapseBuilderTests;
 import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
+import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
+import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
 import org.elasticsearch.search.slice.SliceBuilder;
 import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.search.sort.ScoreSortBuilder;
@@ -600,6 +602,75 @@ public class SearchSourceBuilderTests extends AbstractSearchTestCase {
         }
     }
 
+    public void testStandardRetrieverParsing() throws IOException {
+        String restContent = "{"
+            + "  \"retriever\": {"
+            + "    \"standard\": {"
+            + "      \"query\": {"
+            + "        \"match_all\": {}"
+            + "      },"
+            + "      \"min_score\": 10,"
+            + "      \"_name\": \"foo_standard\""
+            + "    }"
+            + "  }"
+            + "}";
+        SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
+        try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
+            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            assertThat(source.retriever(), instanceOf(StandardRetrieverBuilder.class));
+            StandardRetrieverBuilder parsed = (StandardRetrieverBuilder) source.retriever();
+            assertThat(parsed.minScore(), equalTo(10f));
+            assertThat(parsed.retrieverName(), equalTo("foo_standard"));
+            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
+                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
+                    parseSerialized,
+                    true,
+                    searchUsageHolder,
+                    nf -> true
+                );
+                assertThat(deserializedSource.retriever(), instanceOf(StandardRetrieverBuilder.class));
+                StandardRetrieverBuilder deserialized = (StandardRetrieverBuilder) source.retriever();
+                assertThat(parsed, equalTo(deserialized));
+            }
+        }
+    }
+
+    public void testKnnRetrieverParsing() throws IOException {
+        String restContent = "{"
+            + "  \"retriever\": {"
+            + "    \"knn\": {"
+            + "      \"query_vector\": ["
+            + "        3"
+            + "      ],"
+            + "      \"field\": \"vector\","
+            + "      \"k\": 10,"
+            + "      \"num_candidates\": 15,"
+            + "      \"min_score\": 10,"
+            + "      \"_name\": \"foo_knn\""
+            + "     }"
+            + "  }"
+            + "}";
+        SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
+        try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
+            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            assertThat(source.retriever(), instanceOf(KnnRetrieverBuilder.class));
+            KnnRetrieverBuilder parsed = (KnnRetrieverBuilder) source.retriever();
+            assertThat(parsed.minScore(), equalTo(10f));
+            assertThat(parsed.retrieverName(), equalTo("foo_knn"));
+            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
+                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
+                    parseSerialized,
+                    true,
+                    searchUsageHolder,
+                    nf -> true
+                );
+                assertThat(deserializedSource.retriever(), instanceOf(KnnRetrieverBuilder.class));
+                KnnRetrieverBuilder deserialized = (KnnRetrieverBuilder) source.retriever();
+                assertThat(parsed, equalTo(deserialized));
+            }
+        }
+    }
+
     public void testStoredFieldsUsage() throws IOException {
         Set<String> storedFieldRestVariations = Set.of(
             "{\"stored_fields\" : [\"_none_\"]}",

+ 1 - 1
server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

@@ -74,7 +74,7 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase<Kn
 
     @Override
     protected KnnRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
-        return KnnRetrieverBuilder.fromXContent(
+        return (KnnRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
             parser,
             new RetrieverParserContext(
                 new SearchUsage(),

+ 1 - 1
server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java

@@ -98,7 +98,7 @@ public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCa
 
     @Override
     protected StandardRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
-        return StandardRetrieverBuilder.fromXContent(
+        return (StandardRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
             parser,
             new RetrieverParserContext(
                 new SearchUsage(),

+ 1 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java

@@ -103,10 +103,7 @@ public class RandomRankRetrieverBuilder extends RetrieverBuilder {
 
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.field(RETRIEVER_FIELD.getPreferredName());
-        builder.startObject();
-        builder.field(retrieverBuilder.getName(), retrieverBuilder);
-        builder.endObject();
+        builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
         builder.field(FIELD_FIELD.getPreferredName(), field);
         builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
         if (seed != null) {

+ 1 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -179,17 +179,11 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
 
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.field(RETRIEVER_FIELD.getPreferredName());
-        builder.startObject();
-        builder.field(retrieverBuilder.getName(), retrieverBuilder);
-        builder.endObject();
+        builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
         builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
         builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
         builder.field(FIELD_FIELD.getPreferredName(), field);
         builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
-        if (minScore != null) {
-            builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
-        }
     }
 
     @Override

+ 4 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java

@@ -17,8 +17,6 @@ import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
-import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -48,8 +46,8 @@ public class RandomRankRetrieverBuilderTests extends AbstractXContentTestCase<Ra
     }
 
     @Override
-    protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) {
-        return RandomRankRetrieverBuilder.PARSER.apply(
+    protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
+        return (RandomRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
             parser,
             new RetrieverParserContext(
                 new SearchUsage(),
@@ -77,8 +75,8 @@ public class RandomRankRetrieverBuilderTests extends AbstractXContentTestCase<Ra
         entries.add(
             new NamedXContentRegistry.Entry(
                 RetrieverBuilder.class,
-                new ParseField(TextSimilarityRankBuilder.NAME),
-                (p, c) -> TextSimilarityRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
+                new ParseField(RandomRankBuilder.NAME),
+                (p, c) -> RandomRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
             )
         );
         return new NamedXContentRegistry(entries);

+ 44 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
@@ -25,6 +26,8 @@ import org.elasticsearch.search.retriever.TestRetrieverBuilder;
 import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.usage.SearchUsage;
+import org.elasticsearch.usage.SearchUsageHolder;
+import org.elasticsearch.usage.UsageService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentParser;
@@ -72,8 +75,8 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
     }
 
     @Override
-    protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) {
-        return TextSimilarityRankRetrieverBuilder.PARSER.apply(
+    protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
+        return (TextSimilarityRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
             parser,
             new RetrieverParserContext(
                 new SearchUsage(),
@@ -208,6 +211,45 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
         }
     }
 
+    public void testTextSimilarityRetrieverParsing() throws IOException {
+        String restContent = "{"
+            + "  \"retriever\": {"
+            + "    \"text_similarity_reranker\": {"
+            + "      \"retriever\": {"
+            + "        \"test\": {"
+            + "          \"value\": \"my-test-retriever\""
+            + "        }"
+            + "      },"
+            + "      \"field\": \"my-field\","
+            + "      \"inference_id\": \"my-inference-id\","
+            + "      \"inference_text\": \"my-inference-text\","
+            + "      \"rank_window_size\": 100,"
+            + "      \"min_score\": 20.0,"
+            + "      \"_name\": \"foo_reranker\""
+            + "    }"
+            + "  }"
+            + "}";
+        SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
+        try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
+            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
+            TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever();
+            assertThat(parsed.minScore(), equalTo(20f));
+            assertThat(parsed.retrieverName(), equalTo("foo_reranker"));
+            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
+                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
+                    parseSerialized,
+                    true,
+                    searchUsageHolder,
+                    nf -> true
+                );
+                assertThat(deserializedSource.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
+                TextSimilarityRankRetrieverBuilder deserialized = (TextSimilarityRankRetrieverBuilder) source.retriever();
+                assertThat(parsed, equalTo(deserialized));
+            }
+        }
+    }
+
     public void testIsCompound() {
         RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
             @Override

+ 0 - 3
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

@@ -180,10 +180,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
             builder.startArray(RETRIEVERS_FIELD.getPreferredName());
 
             for (var entry : innerRetrievers) {
-                builder.startObject();
-                builder.field(entry.retriever().getName());
                 entry.retriever().toXContent(builder, params);
-                builder.endObject();
             }
             builder.endArray();
         }

+ 56 - 1
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

@@ -8,19 +8,27 @@
 package org.elasticsearch.xpack.rank.rrf;
 
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
 import org.elasticsearch.search.retriever.TestRetrieverBuilder;
 import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.usage.SearchUsage;
+import org.elasticsearch.usage.SearchUsageHolder;
+import org.elasticsearch.usage.UsageService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.json.JsonXContent;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
 public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RRFRetrieverBuilder> {
 
     /**
@@ -53,7 +61,10 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
 
     @Override
     protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
-        return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true));
+        return (RRFRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
+            parser,
+            new RetrieverParserContext(new SearchUsage(), nf -> true)
+        );
     }
 
     @Override
@@ -81,4 +92,48 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RR
         );
         return new NamedXContentRegistry(entries);
     }
+
+    public void testRRFRetrieverParsing() throws IOException {
+        String restContent = "{"
+            + "  \"retriever\": {"
+            + "    \"rrf\": {"
+            + "      \"retrievers\": ["
+            + "        {"
+            + "          \"test\": {"
+            + "            \"value\": \"foo\""
+            + "          }"
+            + "        },"
+            + "        {"
+            + "          \"test\": {"
+            + "            \"value\": \"bar\""
+            + "          }"
+            + "        }"
+            + "      ],"
+            + "      \"rank_window_size\": 100,"
+            + "      \"rank_constant\": 10,"
+            + "      \"min_score\": 20.0,"
+            + "      \"_name\": \"foo_rrf\""
+            + "    }"
+            + "  }"
+            + "}";
+        SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
+        try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
+            SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
+            assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class));
+            RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever();
+            assertThat(parsed.minScore(), equalTo(20f));
+            assertThat(parsed.retrieverName(), equalTo("foo_rrf"));
+            try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
+                SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
+                    parseSerialized,
+                    true,
+                    searchUsageHolder,
+                    nf -> true
+                );
+                assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class));
+                RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever();
+                assertThat(parsed, equalTo(deserialized));
+            }
+        }
+    }
 }