Browse Source

add searchAsync

Signed-off-by: youny626 <zzhu@fandm.edu>
youny626 5 years ago
parent
commit
fa0d1d2956

+ 0 - 2
milvus-sdk-java.iml

@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" version="4" />

+ 27 - 0
src/main/java/io/milvus/client/MilvusClient.java

@@ -19,6 +19,8 @@
 
 package io.milvus.client;
 
+import com.google.common.util.concurrent.ListenableFuture;
+
 import java.util.List;
 
 /** The Milvus Client Interface */
@@ -208,6 +210,31 @@ public interface MilvusClient {
    */
   SearchResponse search(SearchParam searchParam);
 
+  /**
+   * Searches vectors specified by <code>searchParam</code> asynchronously
+   *
+   * @param searchParam the <code>SearchParam</code> object
+   *     <pre>
+   * example usage:
+   * <code>
+   * SearchParam searchParam = new SearchParam.Builder(collectionName)
+   *                                          .withFloatVectors(floatVectors)
+   *                                          .withTopK(topK)
+   *                                          .withPartitionTags(partitionTagsList)
+   *                                          .withParamsInJson("{\"nprobe\": 20}")
+   *                                          .build();
+   * </code>
+   * </pre>
+   *
+   * @return a <code>ListenableFuture</code> object which holds the <code>SearchResponse</code>
+   * @see SearchParam
+   * @see SearchResponse
+   * @see SearchResponse.QueryResult
+   * @see Response
+   * @see ListenableFuture
+   */
+  ListenableFuture<SearchResponse> searchAsync(SearchParam searchParam);
+
   /**
    * Searches vectors in specific files
    *

+ 75 - 0
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -19,14 +19,17 @@
 
 package io.milvus.client;
 
+import com.google.common.util.concurrent.*;
 import com.google.protobuf.ByteString;
 import io.grpc.ConnectivityState;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.StatusRuntimeException;
+import io.milvus.grpc.TopKQueryResult;
 import org.apache.commons.collections4.ListUtils;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 import java.nio.Buffer;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
@@ -46,6 +49,7 @@ public class MilvusGrpcClient implements MilvusClient {
   private final String extraParamKey = "params";
   private ManagedChannel channel = null;
   private io.milvus.grpc.MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub = null;
+  private io.milvus.grpc.MilvusServiceGrpc.MilvusServiceFutureStub futureStub = null;
 
   ////////////////////// Constructor //////////////////////
   public MilvusGrpcClient() {
@@ -101,6 +105,7 @@ public class MilvusGrpcClient implements MilvusClient {
       }
 
       blockingStub = io.milvus.grpc.MilvusServiceGrpc.newBlockingStub(channel);
+      futureStub = io.milvus.grpc.MilvusServiceGrpc.newFutureStub(channel);
 
     } catch (Exception e) {
       if (!(e instanceof ConnectFailedException)) {
@@ -490,6 +495,76 @@ public class MilvusGrpcClient implements MilvusClient {
     }
   }
 
+  @Override
+  public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
+
+    if (!channelIsReadyOrIdle()) {
+      logWarning("You are not connected to Milvus server");
+      SearchResponse searchResponse = new SearchResponse();
+      searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
+      return Futures.immediateFuture(searchResponse);
+    }
+
+    List<io.milvus.grpc.RowRecord> rowRecordList =
+        buildRowRecordList(searchParam.getFloatVectors(), searchParam.getBinaryVectors());
+
+    io.milvus.grpc.KeyValuePair extraParam =
+        io.milvus.grpc.KeyValuePair.newBuilder()
+            .setKey(extraParamKey)
+            .setValue(searchParam.getParamsInJson())
+            .build();
+
+    io.milvus.grpc.SearchParam request =
+        io.milvus.grpc.SearchParam.newBuilder()
+            .setTableName(searchParam.getCollectionName())
+            .addAllQueryRecordArray(rowRecordList)
+            .addAllPartitionTagArray(searchParam.getPartitionTags())
+            .setTopk(searchParam.getTopK())
+            .addExtraParams(extraParam)
+            .build();
+
+    ListenableFuture<io.milvus.grpc.TopKQueryResult> response;
+
+    response = futureStub.search(request);
+
+    Futures.addCallback(
+        response,
+        new FutureCallback<io.milvus.grpc.TopKQueryResult>() {
+          @Override
+          public void onSuccess(@Nullable io.milvus.grpc.TopKQueryResult result) {
+            logInfo("Search completed successfully!");
+          }
+
+          @Override
+          public void onFailure(Throwable t) {
+            logSevere("Search failed:\n{0}", t.getMessage());
+          }
+        },
+        MoreExecutors.directExecutor());
+
+    com.google.common.base.Function<TopKQueryResult, SearchResponse> transformFunc =
+        new com.google.common.base.Function<io.milvus.grpc.TopKQueryResult, SearchResponse>() {
+          @Override
+          public SearchResponse apply(io.milvus.grpc.TopKQueryResult topKQueryResult) {
+
+            if (topKQueryResult.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
+              SearchResponse searchResponse = buildSearchResponse(topKQueryResult);
+              searchResponse.setResponse(new Response(Response.Status.SUCCESS));
+              return searchResponse;
+            } else {
+              SearchResponse searchResponse = new SearchResponse();
+              searchResponse.setResponse(
+                  new Response(
+                      Response.Status.valueOf(topKQueryResult.getStatus().getErrorCodeValue()),
+                      topKQueryResult.getStatus().getReason()));
+              return searchResponse;
+            }
+          }
+        };
+
+    return Futures.transform(response, transformFunc, MoreExecutors.directExecutor());
+  }
+
   @Override
   public SearchResponse searchInFiles(
       @Nonnull List<String> fileIds, @Nonnull SearchParam searchParam) {