Browse Source

Add searchScroll method to high level REST client (#24938)

Luca Cavanna 8 years ago
parent
commit
856235fac2

+ 8 - 2
client/rest-high-level/src/main/java/org/elasticsearch/client/Request.java

@@ -34,6 +34,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.get.GetRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.action.support.ActiveShardCount;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.WriteRequest;
@@ -63,7 +64,7 @@ import java.util.StringJoiner;
 
 final class Request {
 
-    private static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON;
+    static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON;
 
     final String method;
     final String endpoint;
@@ -338,6 +339,11 @@ final class Request {
         return new Request(HttpGet.METHOD_NAME, endpoint, params.getParams(), entity);
     }
 
+    static Request searchScroll(SearchScrollRequest searchScrollRequest) throws IOException {
+        HttpEntity entity = createEntity(searchScrollRequest, REQUEST_BODY_CONTENT_TYPE);
+        return new Request("GET", "/_search/scroll", Collections.emptyMap(), entity);
+    }
+
     private static HttpEntity createEntity(ToXContent toXContent, XContentType xContentType) throws IOException {
         BytesRef source = XContentHelper.toXContent(toXContent, xContentType, false).toBytesRef();
         return new ByteArrayEntity(source.bytes, source.offset, source.length, ContentType.create(xContentType.mediaType()));
@@ -483,7 +489,7 @@ final class Request {
             return this;
         }
 
-        Params withIndicesOptions (IndicesOptions indicesOptions) {
+        Params withIndicesOptions(IndicesOptions indicesOptions) {
             putParam("ignore_unavailable", Boolean.toString(indicesOptions.ignoreUnavailable()));
             putParam("allow_no_indices", Boolean.toString(indicesOptions.allowNoIndices()));
             String expandWildcards;

+ 23 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java

@@ -38,6 +38,7 @@ import org.elasticsearch.action.main.MainRequest;
 import org.elasticsearch.action.main.MainResponse;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.common.CheckedFunction;
@@ -325,6 +326,27 @@ public class RestHighLevelClient {
         performRequestAsyncAndParseEntity(searchRequest, Request::search, SearchResponse::fromXContent, listener, emptySet(), headers);
     }
 
+    /**
+     * Executes a search using the Search Scroll api
+     *
+     * See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-scroll.html">Search Scroll
+     * API on elastic.co</a>
+     */
+    public SearchResponse searchScroll(SearchScrollRequest searchScrollRequest, Header... headers) throws IOException {
+        return performRequestAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent, emptySet(), headers);
+    }
+
+    /**
+     * Asynchronously executes a search using the Search Scroll api
+     *
+     * See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-scroll.html">Search Scroll
+     * API on elastic.co</a>
+     */
+    public void searchScrollAsync(SearchScrollRequest searchScrollRequest, ActionListener<SearchResponse> listener, Header... headers) {
+        performRequestAsyncAndParseEntity(searchScrollRequest, Request::searchScroll, SearchResponse::fromXContent,
+                listener, emptySet(), headers);
+    }
+
     private <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
                                                                             CheckedFunction<Req, Request, IOException> requestConverter,
                                                                             CheckedFunction<XContentParser, Resp, IOException> entityParser,
@@ -354,6 +376,7 @@ public class RestHighLevelClient {
             }
             throw parseResponseException(e);
         }
+
         try {
             return responseConverter.apply(response);
         } catch(Exception e) {

+ 22 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RequestTests.java

@@ -29,6 +29,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.get.GetRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.WriteRequest;
@@ -40,6 +41,7 @@ import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.lucene.uid.Versions;
+import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -714,10 +716,27 @@ public class RequestTests extends ESTestCase {
         if (searchSourceBuilder == null) {
             assertNull(request.entity);
         } else {
-            BytesReference expectedBytes = XContentHelper.toXContent(searchSourceBuilder, XContentType.JSON, false);
-            assertEquals(XContentType.JSON.mediaType(), request.entity.getContentType().getValue());
-            assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(request.entity)));
+            assertToXContentBody(searchSourceBuilder, request.entity);
+        }
+    }
+
+    public void testSearchScroll() throws IOException {
+        SearchScrollRequest searchScrollRequest = new SearchScrollRequest();
+        searchScrollRequest.scrollId(randomAlphaOfLengthBetween(5, 10));
+        if (randomBoolean()) {
+            searchScrollRequest.scroll(randomPositiveTimeValue());
         }
+        Request request = Request.searchScroll(searchScrollRequest);
+        assertEquals("GET", request.method);
+        assertEquals("/_search/scroll", request.endpoint);
+        assertEquals(0, request.params.size());
+        assertToXContentBody(searchScrollRequest, request.entity);
+    }
+
+    private static void assertToXContentBody(ToXContent expectedBody, HttpEntity actualEntity) throws IOException {
+        BytesReference expectedBytes = XContentHelper.toXContent(expectedBody, Request.REQUEST_BODY_CONTENT_TYPE, false);
+        assertEquals(XContentType.JSON.mediaType(), actualEntity.getContentType().getValue());
+        assertEquals(expectedBytes, new BytesArray(EntityUtils.toByteArray(actualEntity)));
     }
 
     public void testParams() {

+ 45 - 15
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -33,6 +33,7 @@ import org.apache.http.entity.StringEntity;
 import org.apache.http.message.BasicHttpResponse;
 import org.apache.http.message.BasicRequestLine;
 import org.apache.http.message.BasicStatusLine;
+import org.apache.http.nio.entity.NStringEntity;
 import org.elasticsearch.Build;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
@@ -41,21 +42,26 @@ import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.main.MainRequest;
 import org.elasticsearch.action.main.MainResponse;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchResponseSections;
+import org.elasticsearch.action.search.SearchScrollRequest;
+import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.common.CheckedFunction;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.common.xcontent.cbor.CborXContent;
 import org.elasticsearch.common.xcontent.smile.SmileXContent;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.test.ESTestCase;
 import org.junit.Before;
 import org.mockito.ArgumentMatcher;
-import org.mockito.Matchers;
 import org.mockito.internal.matchers.ArrayEquals;
 import org.mockito.internal.matchers.VarargMatcher;
 
@@ -68,6 +74,7 @@ import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.elasticsearch.client.RestClientTestUtil.randomHeaders;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.mockito.Matchers.anyMapOf;
@@ -76,6 +83,8 @@ import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.anyVararg;
 import static org.mockito.Matchers.argThat;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.isNotNull;
+import static org.mockito.Matchers.isNull;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -95,49 +104,70 @@ public class RestHighLevelClientTests extends ESTestCase {
     }
 
     public void testPingSuccessful() throws IOException {
-        Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
+        Header[] headers = randomHeaders(random(), "Header");
         Response response = mock(Response.class);
         when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.OK));
         when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
                 anyObject(), anyVararg())).thenReturn(response);
         assertTrue(restHighLevelClient.ping(headers));
         verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
-                Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+                isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
     }
 
     public void testPing404NotFound() throws IOException {
-        Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
+        Header[] headers = randomHeaders(random(), "Header");
         Response response = mock(Response.class);
         when(response.getStatusLine()).thenReturn(newStatusLine(RestStatus.NOT_FOUND));
         when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
                 anyObject(), anyVararg())).thenReturn(response);
         assertFalse(restHighLevelClient.ping(headers));
         verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
-                Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+                isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
     }
 
     public void testPingSocketTimeout() throws IOException {
-        Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
+        Header[] headers = randomHeaders(random(), "Header");
         when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
                 anyObject(), anyVararg())).thenThrow(new SocketTimeoutException());
         expectThrows(SocketTimeoutException.class, () -> restHighLevelClient.ping(headers));
         verify(restClient).performRequest(eq("HEAD"), eq("/"), eq(Collections.emptyMap()),
-                Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+                isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
     }
 
     public void testInfo() throws IOException {
-        Header[] headers = RestClientTestUtil.randomHeaders(random(), "Header");
-        Response response = mock(Response.class);
+        Header[] headers = randomHeaders(random(), "Header");
         MainResponse testInfo = new MainResponse("nodeName", Version.CURRENT, new ClusterName("clusterName"), "clusterUuid",
                 Build.CURRENT, true);
-        when(response.getEntity()).thenReturn(
-                new StringEntity(toXContent(testInfo, XContentType.JSON, false).utf8ToString(), ContentType.APPLICATION_JSON));
-        when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
-                anyObject(), anyVararg())).thenReturn(response);
+        mockResponse(testInfo);
         MainResponse receivedInfo = restHighLevelClient.info(headers);
         assertEquals(testInfo, receivedInfo);
         verify(restClient).performRequest(eq("GET"), eq("/"), eq(Collections.emptyMap()),
-                Matchers.isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+                isNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+    }
+
+    public void testSearchScroll() throws IOException {
+        Header[] headers = randomHeaders(random(), "Header");
+        SearchResponse mockSearchResponse = new SearchResponse(new SearchResponseSections(SearchHits.empty(), InternalAggregations.EMPTY,
+                null, false, false, null, 1), randomAlphaOfLengthBetween(5, 10), 5, 5, 100, new ShardSearchFailure[0]);
+        mockResponse(mockSearchResponse);
+        SearchResponse searchResponse = restHighLevelClient.searchScroll(new SearchScrollRequest(randomAlphaOfLengthBetween(5, 10)),
+                headers);
+        assertEquals(mockSearchResponse.getScrollId(), searchResponse.getScrollId());
+        assertEquals(0, searchResponse.getHits().totalHits);
+        assertEquals(5, searchResponse.getTotalShards());
+        assertEquals(5, searchResponse.getSuccessfulShards());
+        assertEquals(100, searchResponse.getTook().getMillis());
+        verify(restClient).performRequest(eq("GET"), eq("/_search/scroll"), eq(Collections.emptyMap()),
+                isNotNull(HttpEntity.class), argThat(new HeadersVarargMatcher(headers)));
+    }
+
+    private void mockResponse(ToXContent toXContent) throws IOException {
+        Response response = mock(Response.class);
+        ContentType contentType = ContentType.parse(Request.REQUEST_BODY_CONTENT_TYPE.mediaType());
+        String requestBody = toXContent(toXContent, Request.REQUEST_BODY_CONTENT_TYPE, false).utf8ToString();
+        when(response.getEntity()).thenReturn(new NStringEntity(requestBody, contentType));
+        when(restClient.performRequest(anyString(), anyString(), anyMapOf(String.class, String.class),
+                anyObject(), anyVararg())).thenReturn(response);
     }
 
     public void testRequestValidation() {