Browse Source

Avoid exception when search result is empty (#1458)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 day ago
parent
commit
b7e25c11bd

+ 11 - 2
sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -237,6 +237,14 @@ public class SearchResultsWrapper extends RowRecordWrapper {
         return idScores;
     }
 
+    /**
+     * Gets how many nq are searched.
+     * @return how many nq are searched
+     */
+    public long getNumQueries() {
+        return results.getNumQueries();
+    }
+
     @Getter
     private static final class Position {
         private final long offset;
@@ -250,11 +258,12 @@ public class SearchResultsWrapper extends RowRecordWrapper {
     private Position getOffsetByIndex(int indexOfTarget) {
         List<Long> kList = results.getTopksList();
 
-        // if the server didn't return separate topK, use same topK value
+        // if the server didn't return separate topK, use same topK value "0"
+        // will return an empty result for each nq instead of throwing an exception
         if (kList.isEmpty()) {
             kList = new ArrayList<>();
             for (long i = 0; i < results.getNumQueries(); ++i) {
-                kList.add(results.getTopK());
+                kList.add(0L);
             }
         }
 

+ 78 - 51
sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -27,6 +27,7 @@ import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.Float16Utils;
 import io.milvus.common.utils.GTsDict;
 import io.milvus.common.utils.JsonUtils;
+import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
 import io.milvus.orm.iterator.QueryIterator;
 import io.milvus.orm.iterator.SearchIterator;
@@ -63,6 +64,7 @@ import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 @Testcontainers(disabledWithoutDocker = true)
 class MilvusClientDockerTest {
@@ -1345,18 +1347,6 @@ class MilvusClientDockerTest {
         R<RpcStatus> createR = client.createCollection(createParam);
         Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
 
-        // insert data to multiple vector fields
-        int rowCount = 10000;
-        List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
-
-        InsertParam insertParam = InsertParam.newBuilder()
-                .withCollectionName(randomCollectionName)
-                .withFields(fields)
-                .build();
-
-        R<MutationResult> insertR = client.insert(insertParam);
-        Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
-
         // create indexes on multiple vector fields
         CreateIndexParam indexParam = CreateIndexParam.newBuilder()
                 .withCollectionName(randomCollectionName)
@@ -1397,53 +1387,86 @@ class MilvusClientDockerTest {
                 .build());
         Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
-        // search on multiple vector fields
-        AnnSearchParam param1 = AnnSearchParam.newBuilder()
-                .withVectorFieldName(DataType.FloatVector.name())
-                .withFloatVectors(utils.generateFloatVectors(1))
-                .withMetricType(MetricType.COSINE)
-                .withParams("{\"nprobe\": 32}")
-                .withLimit(10L)
-                .build();
-
-        AnnSearchParam param2 = AnnSearchParam.newBuilder()
-                .withVectorFieldName(DataType.BinaryVector.name())
-                .withBinaryVectors(utils.generateBinaryVectors(1))
-                .withMetricType(MetricType.HAMMING)
-                .withParams("{}")
-                .withLimit(5L)
-                .build();
-
-        AnnSearchParam param3 = AnnSearchParam.newBuilder()
-                .withVectorFieldName(DataType.SparseFloatVector.name())
-                .withSparseFloatVectors(utils.generateSparseVectors(1))
-                .withMetricType(MetricType.IP)
-                .withParams("{\"drop_ratio_search\":0.2}")
-                .withLimit(7L)
-                .build();
+        // prepare sub requests
+        int nq = 5;
+        long topk = 10L;
+        Function<Integer, HybridSearchParam> genRequestFunc =
+                sparseCount -> {
+                    AnnSearchParam param1 = AnnSearchParam.newBuilder()
+                            .withVectorFieldName(DataType.FloatVector.name())
+                            .withFloatVectors(utils.generateFloatVectors(nq))
+                            .withMetricType(MetricType.COSINE)
+                            .withParams("{\"nprobe\": 32}")
+                            .withLimit(15L)
+                            .build();
+
+                    AnnSearchParam param2 = AnnSearchParam.newBuilder()
+                            .withVectorFieldName(DataType.BinaryVector.name())
+                            .withBinaryVectors(utils.generateBinaryVectors(nq))
+                            .withMetricType(MetricType.HAMMING)
+                            .withParams("{}")
+                            .withLimit(5L)
+                            .build();
+
+                    List<SortedMap<Long, Float>> sparseVEctors = sparseCount > 0 ?
+                            utils.generateSparseVectors(sparseCount) : new ArrayList<>();
+                    AnnSearchParam param3 = AnnSearchParam.newBuilder()
+                            .withVectorFieldName(DataType.SparseFloatVector.name())
+                            .withSparseFloatVectors(sparseVEctors)
+                            .withMetricType(MetricType.IP)
+                            .withParams("{\"drop_ratio_search\":0.2}")
+                            .withLimit(7L)
+                            .build();
+
+                    // search with an empty nq, return error
+                    return HybridSearchParam.newBuilder()
+                            .withCollectionName(randomCollectionName)
+                            .addOutField(DataType.SparseFloatVector.name())
+                            .addSearchRequest(param1)
+                            .addSearchRequest(param2)
+                            .addSearchRequest(param3)
+                            .withLimit(topk)
+                            .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                            .withRanker(WeightedRanker.newBuilder()
+                                    .withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
+                                    .build())
+                            .withOutFields(Collections.singletonList("*"))
+                            .build();
+                };
+
+        // search with an empty nq, return error
+        Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(0));
+
+        // unequal nq, return error
+        Assertions.assertThrows(ParamException.class, ()->genRequestFunc.apply(1));
+
+        // search on empty collection, no result returned
+        R<SearchResults> searchR = client.hybridSearch(genRequestFunc.apply(nq));
+        Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        for (int i = 0; i < results.getNumQueries(); ++i) {
+            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
+            Assertions.assertTrue(scores.isEmpty());
+        }
 
-        HybridSearchParam searchParam = HybridSearchParam.newBuilder()
+        // insert data to multiple vector fields
+        int rowCount = 10000;
+        List<InsertParam.Field> fields = generateColumnsData(schema, rowCount, 0);
+        InsertParam insertParam = InsertParam.newBuilder()
                 .withCollectionName(randomCollectionName)
-                .addOutField(DataType.SparseFloatVector.name())
-                .addSearchRequest(param1)
-                .addSearchRequest(param2)
-                .addSearchRequest(param3)
-                .withLimit(3L)
-                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
-                .withRanker(WeightedRanker.newBuilder()
-                        .withWeights(Lists.newArrayList(0.5f, 0.5f, 1.0f))
-                        .build())
-                .withOutFields(Collections.singletonList("*"))
+                .withFields(fields)
                 .build();
+        R<MutationResult> insertR = client.insert(insertParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
 
-        R<SearchResults> searchR = client.hybridSearch(searchParam);
+        // search on multiple vector fields
+        searchR = client.hybridSearch(genRequestFunc.apply(nq));
         Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
 
-        // print search result
-        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        // check search result
+        results = new SearchResultsWrapper(searchR.getData().getResults());
         List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
         for (SearchResultsWrapper.IDScore score : scores) {
-            System.out.println(score);
             Object id = score.get("id");
             Assertions.assertInstanceOf(Long.class, id);
             Object fv = score.get(DataType.FloatVector.name());
@@ -1457,6 +1480,10 @@ class MilvusClientDockerTest {
             Object sv = score.get(DataType.SparseFloatVector.name());
             Assertions.assertInstanceOf(SortedMap.class, sv);
         }
+        for (int i = 0; i < results.getNumQueries(); ++i) {
+            scores = results.getIDScore(i);
+            Assertions.assertEquals(topk, scores.size());
+        }
 
         // drop collection
         DropCollectionParam dropParam = DropCollectionParam.newBuilder()

+ 4 - 0
sdk-core/src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -2969,6 +2969,8 @@ class MilvusServiceClientTest {
         String fieldName = "test";
         SearchResultData results = SearchResultData.newBuilder()
                 .setTopK(topK)
+                .addTopks(topK)
+                .addTopks(topK) // numQueries=2, the topks list must have 2 elements
                 .setNumQueries(numQueries)
                 .setIds(IDs.newBuilder()
                         .setIntId(LongArray.newBuilder()
@@ -2996,6 +2998,8 @@ class MilvusServiceClientTest {
         // for string id
         results = SearchResultData.newBuilder()
                 .setTopK(topK)
+                .addTopks(topK)
+                .addTopks(topK) // numQueries=2, the topks list must have 2 elements
                 .setNumQueries(numQueries)
                 .setIds(IDs.newBuilder()
                         .setStrId(StringArray.newBuilder()

+ 61 - 39
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -33,6 +33,7 @@ import io.milvus.orm.iterator.QueryIterator;
 import io.milvus.orm.iterator.SearchIterator;
 import io.milvus.orm.iterator.SearchIteratorV2;
 import io.milvus.param.Constant;
+import io.milvus.param.dml.HybridSearchParam;
 import io.milvus.pool.MilvusClientV2Pool;
 import io.milvus.pool.PoolConfig;
 import io.milvus.response.QueryResultsWrapper;
@@ -1010,6 +1011,63 @@ class MilvusClientV2DockerTest {
         Assertions.assertEquals(16, descResp.getFieldNames().size());
         Assertions.assertEquals(3, descResp.getVectorFieldNames().size());
 
+        // prepare sub requests
+        int nq = 5;
+        int topk = 10;
+        Function<Integer, HybridSearchReq> genRequestFunc =
+                sparseCount -> {
+                    List<BaseVector> floatVectors = new ArrayList<>();
+                    List<BaseVector> binaryVectors = new ArrayList<>();
+                    List<BaseVector> sparseVectors = new ArrayList<>();
+                    for (int i = 0; i < nq; i++) {
+                        floatVectors.add(new FloatVec(utils.generateFloatVector()));
+                        binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
+                    }
+                    for (int i = 0; i < sparseCount; i++) {
+                        sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
+                    }
+
+                    List<AnnSearchReq> searchRequests = new ArrayList<>();
+                    searchRequests.add(AnnSearchReq.builder()
+                            .vectorFieldName("float_vector")
+                            .vectors(floatVectors)
+                            .params("{\"nprobe\": 10}")
+                            .limit(15)
+                            .build());
+                    searchRequests.add(AnnSearchReq.builder()
+                            .vectorFieldName("binary_vector")
+                            .vectors(binaryVectors)
+                            .limit(5)
+                            .build());
+                    searchRequests.add(AnnSearchReq.builder()
+                            .vectorFieldName("sparse_vector")
+                            .vectors(sparseVectors)
+                            .limit(7)
+                            .build());
+
+                    return HybridSearchReq.builder()
+                            .collectionName(randomCollectionName)
+                            .searchRequests(searchRequests)
+                            .ranker(new RRFRanker(20))
+                            .limit(topk)
+                            .consistencyLevel(ConsistencyLevel.BOUNDED)
+                            .build();
+        };
+
+        // search with an empty nq, return error
+        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));
+
+        // unequal nq, return error
+        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));
+
+        // search on empty collection, no result returned
+        SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        Assertions.assertEquals(nq, searchResults.size());
+        for (List<SearchResp.SearchResult> result : searchResults) {
+            Assertions.assertTrue(result.isEmpty());
+        }
+
         // insert rows
         long count = 10000;
         List<JsonObject> data = generateRandomData(collectionSchema, count);
@@ -1023,45 +1081,9 @@ class MilvusClientV2DockerTest {
         long rowCount = getRowCount(randomCollectionName);
         Assertions.assertEquals(count, rowCount);
 
-        // hybrid search in collection
-        int nq = 5;
-        int topk = 10;
-        List<BaseVector> floatVectors = new ArrayList<>();
-        List<BaseVector> binaryVectors = new ArrayList<>();
-        List<BaseVector> sparseVectors = new ArrayList<>();
-        for (int i = 0; i < nq; i++) {
-            floatVectors.add(new FloatVec(utils.generateFloatVector()));
-            binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
-            sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
-        }
-
-        List<AnnSearchReq> searchRequests = new ArrayList<>();
-        searchRequests.add(AnnSearchReq.builder()
-                .vectorFieldName("float_vector")
-                .vectors(floatVectors)
-                .params("{\"nprobe\": 10}")
-                .limit(10)
-                .build());
-        searchRequests.add(AnnSearchReq.builder()
-                .vectorFieldName("binary_vector")
-                .vectors(binaryVectors)
-                .limit(50)
-                .build());
-        searchRequests.add(AnnSearchReq.builder()
-                .vectorFieldName("sparse_vector")
-                .vectors(sparseVectors)
-                .limit(100)
-                .build());
-
-        HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
-                .collectionName(randomCollectionName)
-                .searchRequests(searchRequests)
-                .ranker(new RRFRanker(20))
-                .limit(topk)
-                .consistencyLevel(ConsistencyLevel.BOUNDED)
-                .build();
-        SearchResp searchResp = client.hybridSearch(hybridSearchReq);
-        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        // search again, there are results
+        searchResp = client.hybridSearch(genRequestFunc.apply(nq));
+        searchResults = searchResp.getSearchResults();
         Assertions.assertEquals(nq, searchResults.size());
         for (int i = 0; i < nq; i++) {
             List<SearchResp.SearchResult> results = searchResults.get(i);