1
0
Эх сурвалжийг харах

Support `search_type` in Rank Evaluation API (#48542)

Adding support for the `search_type` request parameter to the Ranking Evaluation
API since this parameter can impact the ranking and the metric score and should
be choosen in the same way when evaluating the search as later in the real
search.

Closes #48503
Christoph Büscher 6 жил өмнө
parent
commit
e5646fefa3

+ 1 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/RequestConverters.java

@@ -520,6 +520,7 @@ final class RequestConverters {
 
         Params params = new Params();
         params.withIndicesOptions(rankEvalRequest.indicesOptions());
+        params.putParam("search_type", rankEvalRequest.searchType().name().toLowerCase(Locale.ROOT));
         request.addParameters(params.asMap());
         request.setEntity(createEntity(rankEvalRequest.getRankEvalSpec(), REQUEST_BODY_CONTENT_TYPE));
         return request;

+ 5 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/RequestConvertersTests.java

@@ -1480,6 +1480,10 @@ public class RequestConvertersTests extends ESTestCase {
         RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, indices);
         Map<String, String> expectedParams = new HashMap<>();
         setRandomIndicesOptions(rankEvalRequest::indicesOptions, rankEvalRequest::indicesOptions, expectedParams);
+        if (randomBoolean()) {
+            rankEvalRequest.searchType(randomFrom(SearchType.CURRENTLY_SUPPORTED));
+        }
+        expectedParams.put("search_type", rankEvalRequest.searchType().name().toLowerCase(Locale.ROOT));
 
         Request request = RequestConverters.rankEval(rankEvalRequest);
         StringJoiner endpoint = new StringJoiner("/", "/", "");
@@ -1489,7 +1493,7 @@ public class RequestConvertersTests extends ESTestCase {
         }
         endpoint.add(RestRankEvalAction.ENDPOINT);
         assertEquals(endpoint.toString(), request.getEndpoint());
-        assertEquals(4, request.getParameters().size());
+        assertEquals(5, request.getParameters().size());
         assertEquals(expectedParams, request.getParameters());
         assertToXContentBody(spec, request.getEntity());
     }

+ 27 - 2
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalRequest.java

@@ -19,10 +19,12 @@
 
 package org.elasticsearch.index.rankeval;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -42,6 +44,8 @@ public class RankEvalRequest extends ActionRequest implements IndicesRequest.Rep
     private IndicesOptions indicesOptions  = SearchRequest.DEFAULT_INDICES_OPTIONS;
     private String[] indices = Strings.EMPTY_ARRAY;
 
+    private SearchType searchType = SearchType.DEFAULT;
+
     public RankEvalRequest(RankEvalSpec rankingEvaluationSpec, String[] indices) {
         this.rankingEvaluationSpec = Objects.requireNonNull(rankingEvaluationSpec, "ranking evaluation specification must not be null");
         indices(indices);
@@ -52,6 +56,9 @@ public class RankEvalRequest extends ActionRequest implements IndicesRequest.Rep
         rankingEvaluationSpec = new RankEvalSpec(in);
         indices = in.readStringArray();
         indicesOptions = IndicesOptions.readIndicesOptions(in);
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            searchType = SearchType.fromId(in.readByte());
+        }
     }
 
     RankEvalRequest() {
@@ -111,12 +118,29 @@ public class RankEvalRequest extends ActionRequest implements IndicesRequest.Rep
         this.indicesOptions = Objects.requireNonNull(indicesOptions, "indicesOptions must not be null");
     }
 
+    /**
+     * The search type to execute, defaults to {@link SearchType#DEFAULT}.
+     */
+    public void searchType(SearchType searchType) {
+        this.searchType = Objects.requireNonNull(searchType, "searchType must not be null");
+    }
+
+    /**
+     * The type of search to execute.
+     */
+    public SearchType searchType() {
+        return searchType;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         rankingEvaluationSpec.writeTo(out);
         out.writeStringArray(indices);
         indicesOptions.writeIndicesOptions(out);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeByte(searchType.id());
+        }
     }
 
     @Override
@@ -130,11 +154,12 @@ public class RankEvalRequest extends ActionRequest implements IndicesRequest.Rep
         RankEvalRequest that = (RankEvalRequest) o;
         return Objects.equals(indicesOptions, that.indicesOptions) &&
                 Arrays.equals(indices, that.indices) &&
-                Objects.equals(rankingEvaluationSpec, that.rankingEvaluationSpec);
+                Objects.equals(rankingEvaluationSpec, that.rankingEvaluationSpec) &&
+                Objects.equals(searchType, that.searchType);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(indicesOptions, Arrays.hashCode(indices), rankingEvaluationSpec);
+        return Objects.hash(indicesOptions, Arrays.hashCode(indices), rankingEvaluationSpec, searchType);
     }
 }

+ 4 - 0
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RestRankEvalAction.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.index.rankeval;
 
+import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.common.Strings;
@@ -109,6 +110,9 @@ public class RestRankEvalAction extends BaseRestHandler {
     private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
         rankEvalRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
         rankEvalRequest.indicesOptions(IndicesOptions.fromRequest(request, rankEvalRequest.indicesOptions()));
+        if (request.hasParam("search_type")) {
+            rankEvalRequest.searchType(SearchType.fromString(request.param("search_type")));
+        }
         RankEvalSpec spec = RankEvalSpec.parse(parser);
         rankEvalRequest.setRankEvalSpec(spec);
     }

+ 1 - 0
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/TransportRankEvalAction.java

@@ -129,6 +129,7 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
             }
             SearchRequest searchRequest = new SearchRequest(request.indices(), evaluationRequest);
             searchRequest.indicesOptions(request.indicesOptions());
+            searchRequest.searchType(request.searchType());
             msearchRequest.add(searchRequest);
         }
         assert ratedRequestsInSearch.size() == msearchRequest.requests().size();

+ 11 - 0
modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalRequestTests.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.index.rankeval;
 
+import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable.Reader;
@@ -62,6 +63,7 @@ public class RankEvalRequestTests extends AbstractWireSerializingTestCase<RankEv
                 randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(),
             randomBoolean());
         rankEvalRequest.indicesOptions(indicesOptions);
+        rankEvalRequest.searchType(randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH));
         return rankEvalRequest;
     }
 
@@ -77,8 +79,17 @@ public class RankEvalRequestTests extends AbstractWireSerializingTestCase<RankEv
         mutators.add(() -> mutation.indices(ArrayUtils.concat(instance.indices(), new String[] { randomAlphaOfLength(10) })));
         mutators.add(() -> mutation.indicesOptions(randomValueOtherThan(instance.indicesOptions(),
                 () -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()))));
+        mutators.add(() -> {
+            if (instance.searchType() == SearchType.DFS_QUERY_THEN_FETCH) {
+                mutation.searchType(SearchType.QUERY_THEN_FETCH);
+            } else {
+                mutation.searchType(SearchType.DFS_QUERY_THEN_FETCH);
+            }
+        });
+        mutators.add(() -> mutation.setRankEvalSpec(RankEvalSpecTests.mutateTestItem(instance.getRankEvalSpec())));
         mutators.add(() -> mutation.setRankEvalSpec(RankEvalSpecTests.mutateTestItem(instance.getRankEvalSpec())));
         randomFrom(mutators).run();
         return mutation;
     }
+
 }

+ 78 - 0
modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/TransportRankEvalActionTests.java

@@ -0,0 +1,78 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.index.rankeval;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.MultiSearchRequest;
+import org.elasticsearch.action.search.MultiSearchResponse;
+import org.elasticsearch.action.search.SearchType;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.transport.TransportService;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.mockito.Mockito.mock;
+
+public class TransportRankEvalActionTests extends ESTestCase {
+
+    private Settings settings = Settings.builder().put("path.home", createTempDir().toString()).put("node.name", "test-" + getTestName())
+            .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build();
+
+    /**
+     * Test that request parameters like indicesOptions or searchType from ranking evaluation request are transfered to msearch request
+     */
+    public void testTransferRequestParameters() throws Exception {
+        String indexName = "test_index";
+        List<RatedRequest> specifications = new ArrayList<>();
+        specifications
+                .add(new RatedRequest("amsterdam_query", Arrays.asList(new RatedDocument(indexName, "1", 3)), new SearchSourceBuilder()));
+        RankEvalRequest rankEvalRequest = new RankEvalRequest(new RankEvalSpec(specifications, new DiscountedCumulativeGain()),
+                new String[] { indexName });
+        SearchType expectedSearchType = randomFrom(SearchType.CURRENTLY_SUPPORTED);
+        rankEvalRequest.searchType(expectedSearchType);
+        IndicesOptions expectedIndicesOptions = IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(),
+                randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean());
+        rankEvalRequest.indicesOptions(expectedIndicesOptions);
+
+        NodeClient client = new NodeClient(settings, null) {
+            @Override
+            public void multiSearch(MultiSearchRequest request, ActionListener<MultiSearchResponse> listener) {
+                assertEquals(1, request.requests().size());
+                assertEquals(expectedSearchType, request.requests().get(0).searchType());
+                assertArrayEquals(new String[]{indexName}, request.requests().get(0).indices());
+                assertEquals(expectedIndicesOptions, request.requests().get(0).indicesOptions());
+            }
+        };
+
+        TransportRankEvalAction action = new TransportRankEvalAction(mock(ActionFilters.class), client, mock(TransportService.class),
+                mock(ScriptService.class), NamedXContentRegistry.EMPTY);
+        action.doExecute(null, rankEvalRequest, null);
+    }
+}

+ 1 - 0
modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml

@@ -43,6 +43,7 @@ setup:
   - do:
       rank_eval:
         index: foo,
+        search_type: query_then_fetch
         body: {
           "requests" : [
             {

+ 8 - 0
rest-api-spec/src/main/resources/rest-api-spec/api/rank_eval.json

@@ -48,6 +48,14 @@
         ],
         "default":"open",
         "description":"Whether to expand wildcard expression to concrete indices that are open, closed or both."
+      },
+      "search_type":{
+        "type":"enum",
+        "options":[
+          "query_then_fetch",
+          "dfs_query_then_fetch"
+        ],
+        "description":"Search operation type"
       }
     },
     "body":{