Browse Source

Remove id=-1 from topK query output

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
sahuang 5 năm trước cách đây
mục cha
commit
350b656b85

+ 13 - 5
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -669,11 +669,11 @@ public class MilvusGrpcClient implements MilvusClient {
         SearchResponse searchResponse = buildSearchResponse(response);
         searchResponse.setResponse(new Response(Response.Status.SUCCESS));
         logInfo(
-                "Search completed successfully! Returned results for {0} queries",
+                "Search by ids completed successfully! Returned results for {0} queries",
                 searchResponse.getNumQueries());
         return searchResponse;
       } else {
-        logSevere("Search failed:\n{0}", response.getStatus().toString());
+        logSevere("Search by ids failed:\n{0}", response.getStatus().toString());
         SearchResponse searchResponse = new SearchResponse();
         searchResponse.setResponse(
                 new Response(
@@ -682,7 +682,7 @@ public class MilvusGrpcClient implements MilvusClient {
         return searchResponse;
       }
     } catch (StatusRuntimeException e) {
-      logSevere("search RPC failed:\n{0}", e.getStatus().toString());
+      logSevere("search by ids RPC failed:\n{0}", e.getStatus().toString());
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.RPC_ERROR, e.toString()));
       return searchResponse;
@@ -1432,9 +1432,17 @@ public class MilvusGrpcClient implements MilvusClient {
 
     List<List<Long>> resultIdsList = new ArrayList<>();
     List<List<Float>> resultDistancesList = new ArrayList<>();
+
     if (topK > 0) {
-      resultIdsList = ListUtils.partition(topKQueryResult.getIdsList(), topK);
-      resultDistancesList = ListUtils.partition(topKQueryResult.getDistancesList(), topK);
+      for (int i = 0; i < numQueries; i++) {
+        // Process result of query i
+        int pos = i * topK;
+        while (pos < i * topK + topK && topKQueryResult.getIdsList().get(pos) != -1) {
+          pos++;
+        }
+        resultIdsList.add(topKQueryResult.getIdsList().subList(i * topK, pos));
+        resultDistancesList.add(topKQueryResult.getDistancesList().subList(i * topK, pos));
+      }
     }
 
     SearchResponse searchResponse = new SearchResponse();

+ 1 - 1
src/main/java/io/milvus/client/SearchResponse.java

@@ -60,7 +60,7 @@ public class SearchResponse {
     return IntStream.range(0, numQueries)
         .mapToObj(
             i ->
-                LongStream.range(0, topK)
+                LongStream.range(0, resultIdsList.get(i).size())
                     .mapToObj(
                         j ->
                             new QueryResult(

+ 4 - 4
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -357,21 +357,21 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.Test
   void search() {
-    List<List<Float>> vectors = generateFloatVectors(100, dimension);
+    List<List<Float>> vectors = generateFloatVectors(size, dimension);
     vectors = vectors.stream().map(MilvusClientTest::normalizeVector).collect(Collectors.toList());
     InsertParam insertParam =
         new InsertParam.Builder(randomCollectionName).withFloatVectors(vectors).build();
     InsertResponse insertResponse = client.insert(insertParam);
     assertTrue(insertResponse.ok());
     List<Long> vectorIds = insertResponse.getVectorIds();
-    assertEquals(100, vectorIds.size());
+    assertEquals(size, vectorIds.size());
 
     assertTrue(client.flush(randomCollectionName).ok());
 
     final int searchSize = 5;
     List<List<Float>> vectorsToSearch = vectors.subList(0, searchSize);
 
-    final long topK = 200;
+    final long topK = 10;
     SearchParam searchParam =
         new SearchParam.Builder(randomCollectionName)
             .withFloatVectors(vectorsToSearch)
@@ -385,8 +385,8 @@ class MilvusClientTest {
     List<List<Float>> resultDistancesList = searchResponse.getResultDistancesList();
     assertEquals(searchSize, resultDistancesList.size());
     List<List<SearchResponse.QueryResult>> queryResultsList = searchResponse.getQueryResultsList();
-    System.out.println(queryResultsList.get(0).size());
     assertEquals(searchSize, queryResultsList.size());
+
     final double epsilon = 0.001;
     for (int i = 0; i < searchSize; i++) {
       SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);