Browse Source

add searchAsync unit test

Signed-off-by: youny626 <zzhu@fandm.edu>
youny626 5 years ago
parent
commit
bbe20fe0d5

+ 2 - 2
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -531,7 +531,7 @@ public class MilvusGrpcClient implements MilvusClient {
         response,
         new FutureCallback<io.milvus.grpc.TopKQueryResult>() {
           @Override
-          public void onSuccess(@Nullable io.milvus.grpc.TopKQueryResult result) {
+          public void onSuccess(io.milvus.grpc.TopKQueryResult result) {
             if (result.getStatus().getErrorCode() == io.milvus.grpc.ErrorCode.SUCCESS) {
               logInfo(
                   "SearchAsync completed successfully! Returned results for {0} queries",
@@ -548,7 +548,7 @@ public class MilvusGrpcClient implements MilvusClient {
         },
         MoreExecutors.directExecutor());
 
-    com.google.common.base.Function<TopKQueryResult, SearchResponse> transformFunc =
+    com.google.common.base.Function<io.milvus.grpc.TopKQueryResult, SearchResponse> transformFunc =
         new com.google.common.base.Function<io.milvus.grpc.TopKQueryResult, SearchResponse>() {
           @Override
           public SearchResponse apply(io.milvus.grpc.TopKQueryResult topKQueryResult) {

+ 44 - 0
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -19,10 +19,12 @@
 
 package io.milvus.client;
 
+import com.google.common.util.concurrent.ListenableFuture;
 import org.apache.commons.text.RandomStringGenerator;
 
 import java.nio.ByteBuffer;
 import java.util.*;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
@@ -361,6 +363,48 @@ class MilvusClientTest {
     }
   }
 
+  @org.junit.jupiter.api.Test
+  void searchAsync() throws ExecutionException, InterruptedException {
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
+    vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
+    InsertParam insertParam =
+        new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
+    InsertResponse insertResponse = client.insert(insertParam);
+    assertTrue(insertResponse.ok());
+    List<Long> vectorIds = insertResponse.getVectorIds();
+    assertEquals(size, vectorIds.size());
+
+    assertTrue(client.flush(randomCollectionName).ok());
+
+    final int searchSize = 5;
+    List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
+
+    final long topK = 10;
+    SearchParam searchParam =
+        new SearchParam.Builder(randomCollectionName)
+            .withFloatVectors(vectorsToSearch)
+            .withTopK(topK)
+            .withParamsInJson("{\"nprobe\": 20}")
+            .build();
+    ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
+    SearchResponse searchResponse = searchResponseFuture.get();
+    assertTrue(searchResponse.ok());
+    List<List<Long>> resultIdsList = searchResponse.getResultIdsList();
+    assertEquals(searchSize, resultIdsList.size());
+    List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
+    assertEquals(searchSize, resultDistancesList.size());
+    List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
+    assertEquals(searchSize, queryResultsList.size());
+    final double epsilon = 0.001;
+    for (int i = 0; i < searchSize; i++) {
+      SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
+      assertEquals(vectorIds.get(i), firstQueryResult.getVectorId());
+      assertEquals(vectorIds.get(i), resultIdsList.get(i).get(0));
+      assertTrue(Math.abs(1 - firstQueryResult.getDistance()) < epsilon);
+      assertTrue(Math.abs(1 - resultDistancesList.get(i).get(0)) < epsilon);
+    }
+  }
+
   @org.junit.jupiter.api.Test
   void searchBinary() {
     final long binaryDimension = 10000;