1
0
Эх сурвалжийг харах

Switch to using RescoreBuilder in SearchSourceBuilder

Christoph Büscher 9 жил өмнө
parent
commit
1550d0f013

+ 15 - 8
core/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -19,14 +19,6 @@
 
 package org.elasticsearch.search;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.function.Supplier;
-
 import org.apache.lucene.search.BooleanQuery;
 import org.elasticsearch.common.geo.ShapesAvailability;
 import org.elasticsearch.common.geo.builders.CircleBuilder;
@@ -227,9 +219,19 @@ import org.elasticsearch.search.highlight.HighlightPhase;
 import org.elasticsearch.search.highlight.Highlighter;
 import org.elasticsearch.search.highlight.Highlighters;
 import org.elasticsearch.search.query.QueryPhase;
+import org.elasticsearch.search.rescore.QueryRescorerBuilder;
+import org.elasticsearch.search.rescore.RescoreBuilder;
 import org.elasticsearch.search.suggest.Suggester;
 import org.elasticsearch.search.suggest.Suggesters;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
 /**
  *
  */
@@ -327,6 +329,7 @@ public class SearchModule extends AbstractModule {
         bind(IndicesQueriesRegistry.class).toInstance(buildQueryParserRegistry());
         configureFetchSubPhase();
         configureShapes();
+        configureRescorers();
     }
 
     protected void configureFetchSubPhase() {
@@ -467,6 +470,10 @@ public class SearchModule extends AbstractModule {
         }
     }
 
+    private void configureRescorers() {
+        namedWriteableRegistry.registerPrototype(RescoreBuilder.class, QueryRescorerBuilder.PROTOTYPE);
+    }
+
     private void registerBuiltinFunctionScoreParsers() {
         registerFunctionScoreParser(new ScriptScoreFunctionParser());
         registerFunctionScoreParser(new GaussDecayFunctionParser());

+ 6 - 25
core/src/main/java/org/elasticsearch/search/SearchService.java

@@ -23,6 +23,7 @@ import com.carrotsearch.hppc.ObjectFloatHashMap;
 import com.carrotsearch.hppc.ObjectHashSet;
 import com.carrotsearch.hppc.ObjectSet;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
+
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.NumericDocValues;
@@ -100,6 +101,7 @@ import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.query.QuerySearchResultProvider;
 import org.elasticsearch.search.query.ScrollQuerySearchResult;
+import org.elasticsearch.search.rescore.RescoreBaseBuilder;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.io.IOException;
@@ -772,33 +774,12 @@ public class SearchService extends AbstractLifecycleComponent<SearchService> imp
             }
         }
         if (source.rescores() != null) {
-            XContentParser completeRescoreParser = null;
             try {
-                XContentBuilder completeRescoreBuilder = XContentFactory.jsonBuilder();
-                completeRescoreBuilder.startObject();
-                completeRescoreBuilder.startArray("rescore");
-                for (BytesReference rescore : source.rescores()) {
-                    XContentParser parser = XContentFactory.xContent(rescore).createParser(rescore);
-                    parser.nextToken();
-                    completeRescoreBuilder.copyCurrentStructure(parser);
+                for (RescoreBaseBuilder rescore : source.rescores()) {
+                    context.addRescore(rescore.build(context.indexShard().getQueryShardContext()));
                 }
-                completeRescoreBuilder.endArray();
-                completeRescoreBuilder.endObject();
-                BytesReference completeRescoreBytes = completeRescoreBuilder.bytes();
-                completeRescoreParser = XContentFactory.xContent(completeRescoreBytes).createParser(completeRescoreBytes);
-                completeRescoreParser.nextToken();
-                completeRescoreParser.nextToken();
-                completeRescoreParser.nextToken();
-                this.elementParsers.get("rescore").parse(completeRescoreParser, context);
-            } catch (Exception e) {
-                String sSource = "_na_";
-                try {
-                    sSource = source.toString();
-                } catch (Throwable e1) {
-                    // ignore
-                }
-                XContentLocation location = completeRescoreParser != null ? completeRescoreParser.getTokenLocation() : null;
-                throw new SearchParseException(context, "failed to parse rescore source [" + sSource + "]", location, e);
+            } catch (IOException e) {
+                throw new SearchContextException(context, "failed to create RescoreSearchContext", e);
             }
         }
         if (source.fields() != null) {

+ 12 - 22
core/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

@@ -21,6 +21,7 @@ package org.elasticsearch.search.builder;
 
 import com.carrotsearch.hppc.ObjectFloatHashMap;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
+
 import org.elasticsearch.Version;
 import org.elasticsearch.action.support.ToXContentToBytes;
 import org.elasticsearch.common.Nullable;
@@ -151,7 +152,7 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
 
     private BytesReference innerHitsBuilder;
 
-    private List<BytesReference> rescoreBuilders;
+    private List<RescoreBaseBuilder> rescoreBuilders;
 
     private ObjectFloatHashMap<String> indexBoost = null;
 
@@ -459,19 +460,11 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
     }
 
     public SearchSourceBuilder addRescorer(RescoreBaseBuilder rescoreBuilder) {
-        try {
             if (rescoreBuilders == null) {
                 rescoreBuilders = new ArrayList<>();
             }
-            XContentBuilder builder = XContentFactory.jsonBuilder();
-            builder.startObject();
-            rescoreBuilder.toXContent(builder, EMPTY_PARAMS);
-            builder.endObject();
-            rescoreBuilders.add(builder.bytes());
+            rescoreBuilders.add(rescoreBuilder);
             return this;
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
     }
 
     public SearchSourceBuilder clearRescorers() {
@@ -498,7 +491,7 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
     /**
      * Gets the bytes representing the rescore builders for this request.
      */
-    public List<BytesReference> rescores() {
+    public List<RescoreBaseBuilder> rescores() {
         return rescoreBuilders;
     }
 
@@ -878,10 +871,9 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
                     }
                     builder.sorts = sorts;
                 } else if (context.parseFieldMatcher().match(currentFieldName, RESCORE_FIELD)) {
-                    List<BytesReference> rescoreBuilders = new ArrayList<>();
+                    List<RescoreBaseBuilder> rescoreBuilders = new ArrayList<>();
                     while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
-                        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().copyCurrentStructure(parser);
-                        rescoreBuilders.add(xContentBuilder.bytes());
+                        rescoreBuilders.add(RescoreBaseBuilder.PROTOTYPE.fromXContent(context));
                     }
                     builder.rescoreBuilders = rescoreBuilders;
                 } else if (context.parseFieldMatcher().match(currentFieldName, STATS_FIELD)) {
@@ -1048,10 +1040,8 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
 
         if (rescoreBuilders != null) {
             builder.startArray(RESCORE_FIELD.getPreferredName());
-            for (BytesReference rescoreBuilder : rescoreBuilders) {
-                XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(rescoreBuilder);
-                parser.nextToken();
-                builder.copyCurrentStructure(parser);
+            for (RescoreBaseBuilder rescoreBuilder : rescoreBuilders) {
+                rescoreBuilder.toXContent(builder, params);
             }
             builder.endArray();
         }
@@ -1197,9 +1187,9 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
         }
         if (in.readBoolean()) {
             int size = in.readVInt();
-            List<BytesReference> rescoreBuilders = new ArrayList<>();
+            List<RescoreBaseBuilder> rescoreBuilders = new ArrayList<>();
             for (int i = 0; i < size; i++) {
-                rescoreBuilders.add(in.readBytesReference());
+                rescoreBuilders.add(RescoreBaseBuilder.PROTOTYPE.readFrom(in));
             }
             builder.rescoreBuilders = rescoreBuilders;
         }
@@ -1313,8 +1303,8 @@ public final class SearchSourceBuilder extends ToXContentToBytes implements Writ
         out.writeBoolean(hasRescoreBuilders);
         if (hasRescoreBuilders) {
             out.writeVInt(rescoreBuilders.size());
-            for (BytesReference rescoreBuilder : rescoreBuilders) {
-                out.writeBytesReference(rescoreBuilder);
+            for (RescoreBaseBuilder rescoreBuilder : rescoreBuilders) {
+                rescoreBuilder.writeTo(out);
             }
         }
         boolean hasScriptFields = scriptFields != null;

+ 2 - 0
core/src/main/java/org/elasticsearch/search/rescore/RescoreBaseBuilder.java

@@ -115,10 +115,12 @@ public class RescoreBaseBuilder implements ToXContent, Writeable<RescoreBaseBuil
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
         if (windowSize != null) {
             builder.field("window_size", windowSize);
         }
         rescorer.toXContent(builder, params);
+        builder.endObject();
         return builder;
     }
 

+ 7 - 10
core/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java

@@ -19,11 +19,6 @@
 
 package org.elasticsearch.search.builder;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.TimeUnit;
-
 import org.elasticsearch.common.ParseFieldMatcher;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -57,7 +52,7 @@ import org.elasticsearch.search.fetch.innerhits.InnerHitsBuilder;
 import org.elasticsearch.search.fetch.innerhits.InnerHitsBuilder.InnerHit;
 import org.elasticsearch.search.fetch.source.FetchSourceContext;
 import org.elasticsearch.search.highlight.HighlightBuilderTests;
-import org.elasticsearch.search.rescore.RescoreBaseBuilder;
+import org.elasticsearch.search.rescore.QueryRescoreBuilderTests;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.search.suggest.SuggestBuilder;
@@ -68,6 +63,11 @@ import org.elasticsearch.threadpool.ThreadPoolModule;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
 import static org.hamcrest.Matchers.equalTo;
 
 public class SearchSourceBuilderTests extends ESTestCase {
@@ -280,10 +280,7 @@ public class SearchSourceBuilderTests extends ESTestCase {
         if (randomBoolean()) {
             int numRescores = randomIntBetween(1, 5);
             for (int i = 0; i < numRescores; i++) {
-                // NORELEASE need a random rescore builder method
-                RescoreBaseBuilder rescoreBuilder = new RescoreBaseBuilder(RescoreBaseBuilder.queryRescorer(QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20),
-                        randomAsciiOfLengthBetween(5, 20))));
-                builder.addRescorer(rescoreBuilder);
+                builder.addRescorer(QueryRescoreBuilderTests.randomRescoreBuilder());
             }
         }
         if (randomBoolean()) {

+ 7 - 7
core/src/test/java/org/elasticsearch/search/highlight/HighlightBuilderTests.java

@@ -19,13 +19,6 @@
 
 package org.elasticsearch.search.highlight;
 
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.common.ParseFieldMatcher;
@@ -64,6 +57,13 @@ import org.elasticsearch.test.IndexSettingsModule;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
 

+ 4 - 13
core/src/test/java/org/elasticsearch/search/rescore/QueryRescoreBuilderTests.java

@@ -43,12 +43,11 @@ import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilders;
 import org.elasticsearch.index.mapper.core.StringFieldMapper;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
-import org.elasticsearch.index.query.MatchAllQueryParser;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryParseContext;
-import org.elasticsearch.index.query.QueryParser;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.indices.query.IndicesQueriesRegistry;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.rescore.QueryRescorer.QueryRescoreContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.IndexSettingsModule;
@@ -56,8 +55,6 @@ import org.junit.AfterClass;
 import org.junit.BeforeClass;
 
 import java.io.IOException;
-import java.util.HashSet;
-import java.util.Set;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
@@ -74,11 +71,8 @@ public class QueryRescoreBuilderTests extends ESTestCase {
     @BeforeClass
     public static void init() {
         namedWriteableRegistry = new NamedWriteableRegistry();
-        namedWriteableRegistry.registerPrototype(RescoreBuilder.class, org.elasticsearch.search.rescore.QueryRescorerBuilder.PROTOTYPE);
-        @SuppressWarnings("rawtypes")
-        Set<QueryParser> injectedQueryParsers = new HashSet<>();
-        injectedQueryParsers.add(new MatchAllQueryParser());
-        indicesQueriesRegistry = new IndicesQueriesRegistry(Settings.settingsBuilder().build(), injectedQueryParsers, namedWriteableRegistry);
+        namedWriteableRegistry.registerPrototype(RescoreBuilder.class, QueryRescorerBuilder.PROTOTYPE);
+        indicesQueriesRegistry = new SearchModule(Settings.EMPTY, namedWriteableRegistry).buildQueryParserRegistry();
     }
 
     @AfterClass
@@ -154,10 +148,7 @@ public class QueryRescoreBuilderTests extends ESTestCase {
         if (randomBoolean()) {
             builder.prettyPrint();
         }
-        builder.startObject();
         rescoreBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
-        builder.endObject();
-
         return XContentHelper.createParser(builder.bytes());
     }
 
@@ -326,7 +317,7 @@ public class QueryRescoreBuilderTests extends ESTestCase {
     /**
      * create random shape that is put under test
      */
-    private static RescoreBaseBuilder randomRescoreBuilder() {
+    public static RescoreBaseBuilder randomRescoreBuilder() {
         QueryBuilder<MatchAllQueryBuilder> queryBuilder = new MatchAllQueryBuilder().boost(randomFloat()).queryName(randomAsciiOfLength(20));
         org.elasticsearch.search.rescore.QueryRescorerBuilder rescorer = new
                 org.elasticsearch.search.rescore.QueryRescorerBuilder(queryBuilder);