Browse Source

[cherry-pick] Fix bug of highLevel-get and highLevel-delete and Add test (#606)

* Fix bug of highLevel-delete (#584)

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>

* Fix bug of highLevel-get (#590)

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>

* Add test for highLevel-get and highLevel-delete

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>

---------

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>
xushuang.hu 1 year ago
parent
commit
94202ea742

+ 1 - 3
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -3002,14 +3002,12 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
             }
             }
 
 
             DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
             DescCollResponseWrapper wrapper = new DescCollResponseWrapper(descResp.getData());
-            FieldType primaryField = wrapper.getPrimaryField();
-
             if (CollectionUtils.isEmpty(requestParam.getOutputFields())) {
             if (CollectionUtils.isEmpty(requestParam.getOutputFields())) {
                 FieldType vectorField = wrapper.getVectorField();
                 FieldType vectorField = wrapper.getVectorField();
                 requestParam.getOutputFields().addAll(Lists.newArrayList(Constant.ALL_OUTPUT_FIELDS, vectorField.getName()));
                 requestParam.getOutputFields().addAll(Lists.newArrayList(Constant.ALL_OUTPUT_FIELDS, vectorField.getName()));
             }
             }
 
 
-            String expr = VectorUtils.convertPksExpr(requestParam.getPrimaryIds(), primaryField.getName());
+            String expr = VectorUtils.convertPksExpr(requestParam.getPrimaryIds(), wrapper);
             QueryParam queryParam = QueryParam.newBuilder()
             QueryParam queryParam = QueryParam.newBuilder()
                     .withCollectionName(requestParam.getCollectionName())
                     .withCollectionName(requestParam.getCollectionName())
                     .withExpr(expr)
                     .withExpr(expr)

+ 4 - 1
src/main/java/io/milvus/common/utils/VectorUtils.java

@@ -17,10 +17,13 @@ public class VectorUtils {
             FieldType primaryField = optional.get();
             FieldType primaryField = optional.get();
             switch (primaryField.getDataType()) {
             switch (primaryField.getDataType()) {
                 case Int64:
                 case Int64:
-                case VarChar:
                     List<String> primaryStringIds = primaryIds.stream().map(String::valueOf).collect(Collectors.toList());
                     List<String> primaryStringIds = primaryIds.stream().map(String::valueOf).collect(Collectors.toList());
                     expr = convertPksExpr(primaryStringIds, primaryField.getName());
                     expr = convertPksExpr(primaryStringIds, primaryField.getName());
                     break;
                     break;
+                case VarChar:
+                    List<String> primaryVarcharIds = primaryIds.stream().map(primaryId -> String.format("\"%s\"", primaryId)).collect(Collectors.toList());
+                    expr = convertPksExpr(primaryVarcharIds, primaryField.getName());
+                    break;
                 default:
                 default:
                     throw new ParamException("The primary key is not of type int64 or varchar, and the current operation is not supported.");
                     throw new ParamException("The primary key is not of type int64 or varchar, and the current operation is not supported.");
             }
             }

+ 177 - 0
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -29,6 +29,10 @@ import io.milvus.param.dml.InsertParam;
 import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.param.dml.UpsertParam;
 import io.milvus.param.dml.UpsertParam;
+import io.milvus.param.highlevel.dml.DeleteIdsParam;
+import io.milvus.param.highlevel.dml.GetIdsParam;
+import io.milvus.param.highlevel.dml.response.DeleteResponse;
+import io.milvus.param.highlevel.dml.response.GetResponse;
 import io.milvus.param.index.CreateIndexParam;
 import io.milvus.param.index.CreateIndexParam;
 import io.milvus.param.index.DescribeIndexParam;
 import io.milvus.param.index.DescribeIndexParam;
 import io.milvus.param.index.DropIndexParam;
 import io.milvus.param.index.DropIndexParam;
@@ -1685,4 +1689,177 @@ class MilvusClientDockerTest {
                 .build());
                 .build());
         Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
         Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
     }
     }
+
+    @Test
+    void testHighLevelGet() {
+        // collection schema
+        String field1Name = "id_field";
+        String field2Name = "vector_field";
+        FieldType int64PrimaryField = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build();
+
+        FieldType varcharPrimaryField = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withDataType(DataType.VarChar)
+                .withName(field1Name)
+                .withMaxLength(128)
+                .build();
+
+        FieldType vectorField = FieldType.newBuilder()
+                .withDataType(DataType.FloatVector)
+                .withName(field2Name)
+                .withDimension(dimension)
+                .build();
+
+        testCollectionHighLevelGet(int64PrimaryField, vectorField);
+        testCollectionHighLevelGet(varcharPrimaryField, vectorField);
+    }
+
+    @Test
+    void testHighLevelDelete() {
+        // collection schema
+        String field1Name = "id_field";
+        String field2Name = "vector_field";
+        FieldType int64PrimaryField = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build();
+
+        FieldType varcharPrimaryField = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withDataType(DataType.VarChar)
+                .withName(field1Name)
+                .withMaxLength(128)
+                .build();
+
+        FieldType vectorField = FieldType.newBuilder()
+                .withDataType(DataType.FloatVector)
+                .withName(field2Name)
+                .withDimension(dimension)
+                .build();
+
+        testCollectionHighLevelDelete(int64PrimaryField, vectorField);
+        testCollectionHighLevelDelete(varcharPrimaryField, vectorField);
+    }
+
+    void testCollectionHighLevelGet(FieldType primaryField, FieldType vectorField) {
+        // create collection
+        String randomCollectionName = generator.generate(10);
+        highLevelCreateCollection(primaryField, vectorField, randomCollectionName);
+
+        // insert data
+        List<String> primaryIds = new ArrayList<>();
+        int rowCount = 10;
+        List<JSONObject> rows = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            JSONObject row = new JSONObject();
+            row.put(primaryField.getName(), primaryField.getDataType() == DataType.Int64 ? i : String.valueOf(i));
+            row.put(vectorField.getName(), generateFloatVectors(1).get(0));
+            rows.add(row);
+            primaryIds.add(String.valueOf(i));
+        }
+
+        InsertParam insertRowParam = InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withRows(rows)
+                .build();
+
+        R<MutationResult> insertRowResp = client.insert(insertRowParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), insertRowResp.getStatus().intValue());
+
+        testHighLevelGet(randomCollectionName, primaryIds);
+        client.dropCollection(DropCollectionParam.newBuilder().withCollectionName(randomCollectionName).build());
+    }
+
+    private static void highLevelCreateCollection(FieldType primaryField, FieldType vectorField, String randomCollectionName) {
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withDescription("test")
+                .addFieldType(primaryField)
+                .addFieldType(vectorField)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        // create index
+        CreateIndexParam indexParam = CreateIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFieldName(vectorField.getName())
+                .withIndexName("abv")
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.L2)
+                .withExtraParam("{}")
+                .build();
+
+        R<RpcStatus> createIndexR = client.createIndex(indexParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        // load collection
+        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
+    }
+
+    void testCollectionHighLevelDelete(FieldType primaryField, FieldType vectorField) {
+        // create collection & buildIndex & loadCollection
+        String randomCollectionName = generator.generate(10);
+        highLevelCreateCollection(primaryField, vectorField, randomCollectionName);
+
+        // insert data
+        List<String> primaryIds = new ArrayList<>();
+        int rowCount = 10;
+        List<JSONObject> rows = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            JSONObject row = new JSONObject();
+            row.put(primaryField.getName(), primaryField.getDataType() == DataType.Int64 ? i : String.valueOf(i));
+            row.put(vectorField.getName(), generateFloatVectors(1).get(0));
+            rows.add(row);
+            primaryIds.add(String.valueOf(i));
+        }
+
+        InsertParam insertRowParam = InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withRows(rows)
+                .build();
+
+        R<MutationResult> insertRowResp = client.insert(insertRowParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), insertRowResp.getStatus().intValue());
+
+        // high level delete
+        testHighLevelDelete(randomCollectionName, primaryIds);
+        client.dropCollection(DropCollectionParam.newBuilder().withCollectionName(randomCollectionName).build());
+    }
+
+    private static void testHighLevelGet(String collectionName, List primaryIds) {
+        GetIdsParam getIdsParam = GetIdsParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withPrimaryIds(primaryIds)
+                .build();
+
+        R<GetResponse> getResponseR = client.get(getIdsParam);
+        String outPutStr = String.format("collectionName:%s, primaryIds:%s, getResponseR:%s", collectionName, primaryIds, getResponseR.getData());
+        System.out.println(outPutStr);
+        Assertions.assertEquals(R.Status.Success.getCode(), getResponseR.getStatus().intValue());
+    }
+
+    private static void testHighLevelDelete(String collectionName, List primaryIds) {
+        DeleteIdsParam deleteIdsParam = DeleteIdsParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withPrimaryIds(primaryIds)
+                .build();
+
+        R<DeleteResponse> deleteResponseR = client.delete(deleteIdsParam);
+        String outPutStr = String.format("collectionName:%s, primaryIds:%s, deleteResponseR:%s", collectionName, primaryIds, deleteResponseR);
+        System.out.println(outPutStr);
+        Assertions.assertEquals(R.Status.Success.getCode(), deleteResponseR.getStatus().intValue());
+        Assertions.assertEquals(primaryIds.size(), deleteResponseR.getData().getDeleteIds().size());
+    }
 }
 }

+ 1 - 1
tests/milvustest/src/test/java/com/zilliz/milvustest/tls/TLSTest.java

@@ -137,7 +137,7 @@ public class TLSTest {
                 .search(SearchSimpleParam.newBuilder()
                 .search(SearchSimpleParam.newBuilder()
                 .withCollectionName(collectionName)
                 .withCollectionName(collectionName)
                 .withOffset(0L)
                 .withOffset(0L)
-                .withLimit(100)
+                .withLimit(100L)
                 .withFilter("book_id>5000")
                 .withFilter("book_id>5000")
                 .withVectors(search_vectors)
                 .withVectors(search_vectors)
                 .withOutputFields(Lists.newArrayList("*"))
                 .withOutputFields(Lists.newArrayList("*"))