Ver código fonte

Fix a defect of QueryReq (#1574)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 mês atrás
pai
commit
8b55241380

+ 0 - 1
examples/src/main/java/io/milvus/v2/ArrayFieldExample.java

@@ -142,7 +142,6 @@ public class ArrayFieldExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/BinaryVectorExample.java

@@ -114,7 +114,6 @@ public class BinaryVectorExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/FullTextSearchExample.java

@@ -143,7 +143,6 @@ public class FullTextSearchExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/HybridSearchExample.java

@@ -175,7 +175,6 @@ public class HybridSearchExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/Int8VectorExample.java

@@ -129,7 +129,6 @@ public class Int8VectorExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/IteratorExample.java

@@ -119,7 +119,6 @@ public class IteratorExample {
         // Check row count
         QueryResp queryResp = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/JsonFieldExample.java

@@ -147,7 +147,6 @@ public class JsonFieldExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/NullAndDefaultExample.java

@@ -169,7 +169,6 @@ public class NullAndDefaultExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/SimpleExample.java

@@ -70,7 +70,6 @@ public class SimpleExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(collectionName)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/SparseVectorExample.java

@@ -107,7 +107,6 @@ public class SparseVectorExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 0 - 1
examples/src/main/java/io/milvus/v2/TextMatchExample.java

@@ -159,7 +159,6 @@ public class TextMatchExample {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());

+ 3 - 5
sdk-core/src/main/java/io/milvus/v2/service/vector/VectorService.java

@@ -35,6 +35,7 @@ import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.response.*;
 import io.milvus.v2.utils.DataUtils;
 import io.milvus.v2.utils.VectorUtils;
+import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -223,14 +224,11 @@ public class VectorService extends BaseService {
         String dbName = request.getDatabaseName();
         String collectionName = request.getCollectionName();
         String title = String.format("QueryRequest collectionName:%s, databaseName:%s", collectionName, dbName);
-        if (request.getFilter() == null && request.getIds() == null) {
-            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "filter and ids can't be null at the same time");
-        } else if (request.getFilter() != null && request.getIds() != null) {
+        if (StringUtils.isNotEmpty(request.getFilter()) && CollectionUtils.isNotEmpty(request.getIds())) {
             throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "filter and ids can't be set at the same time");
         }
 
-
-        if (request.getIds() != null && request.getFilter() == null) {
+        if (CollectionUtils.isNotEmpty(request.getIds())) {
             DescribeCollectionResponse descResp = getCollectionInfo(blockingStub, dbName, collectionName, false);
             String primaryKeyName = "";
             List<FieldSchema> fields = descResp.getSchema().getFieldsList();

+ 2 - 1
sdk-core/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java

@@ -36,7 +36,8 @@ public class QueryReq {
     @Builder.Default
     private List<String> outputFields = Collections.singletonList("*");
     private List<Object> ids;
-    private String filter;
+    @Builder.Default
+    private String filter = "";
     @Builder.Default
     private ConsistencyLevel consistencyLevel = null;
     private long offset;

+ 1 - 1
sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java

@@ -58,7 +58,7 @@ public class VectorUtils {
             builder.setDbName(dbName);
         }
 
-        if (request.getFilter() != null && !request.getFilter().isEmpty()) {
+        if (StringUtils.isNotEmpty(request.getFilter())) {
             Map<String, Object> filterTemplateValues = request.getFilterTemplateValues();
             filterTemplateValues.forEach((key, value)->{
                 builder.putExprTemplateValues(key, deduceAndCreateTemplateValue(value));

+ 76 - 15
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -286,7 +286,6 @@ class MilvusClientV2DockerTest {
     private long getRowCount(String collectionName) {
         QueryResp queryResp = client.query(QueryReq.builder()
                 .collectionName(collectionName)
-                .filter("")
                 .outputFields(Collections.singletonList("count(*)"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());
@@ -519,21 +518,83 @@ class MilvusClientV2DockerTest {
             verifyOutput(row, entity);
         }
 
-        // query
-        QueryResp queryResp = client.query(QueryReq.builder()
-                .collectionName(randomCollectionName)
-                .filter("JSON_CONTAINS_ANY(json_field[\"flags\"], [4, 100])")
-                .build());
-        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
-        Assertions.assertEquals(6, queryResults.size());
+        {
+            // query with template
+            Map<String,Object> template = new HashMap<>();
+            template.put("id_arr", Arrays.asList(5, 6, 7));
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .filter("id in {id_arr}")
+                    .filterTemplateValues(template)
+                    .build());
+            List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+            Assertions.assertEquals(3, queryResults.size());
+        }
 
-        // test the withTimeout works well
-        client.withTimeout(1, TimeUnit.NANOSECONDS);
-        Assertions.assertThrows(MilvusClientException.class, ()->client.query(QueryReq.builder()
-                .collectionName(randomCollectionName)
-                .filter("JSON_CONTAINS_ANY(json_field[\"flags\"], [4, 100])")
-                .consistencyLevel(ConsistencyLevel.STRONG)
-                .build()));
+        {
+            // query with limit
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .limit(8)
+                    .build());
+            List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+            Assertions.assertEquals(8, queryResults.size());
+        }
+
+        {
+            // query with limit and filter
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .filter("id > 1")
+                    .limit(8)
+                    .build());
+            List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+            Assertions.assertEquals(8, queryResults.size());
+        }
+
+        {
+            // query with ids
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .ids(Arrays.asList(1, 5, 10))
+                    .build());
+            List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+            Assertions.assertEquals(3, queryResults.size());
+        }
+
+        {
+            // query error with 0 limit and empty filter
+            Assertions.assertThrows(MilvusClientException.class, () -> client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .build()));
+        }
+
+        {
+            // query error with ids and filter
+            Assertions.assertThrows(MilvusClientException.class, () -> client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .filter("id > 1")
+                    .ids(Arrays.asList(1, 3, 5))
+                    .build()));
+        }
+
+        {
+            // query timeout
+            QueryResp queryResp = client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .filter("JSON_CONTAINS_ANY(json_field[\"flags\"], [4, 100])")
+                    .build());
+            List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+            Assertions.assertEquals(6, queryResults.size());
+
+            // test the withTimeout works well
+            client.withTimeout(1, TimeUnit.NANOSECONDS);
+            Assertions.assertThrows(MilvusClientException.class, () -> client.query(QueryReq.builder()
+                    .collectionName(randomCollectionName)
+                    .filter("JSON_CONTAINS_ANY(json_field[\"flags\"], [4, 100])")
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .build()));
+        }
 
         client.withTimeout(0, TimeUnit.SECONDS);
         client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());