Browse Source

Fix drop index bug (#328)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
5c272d5325

+ 12 - 0
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -1054,9 +1054,21 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
         logInfo(requestParam.toString());
 
         try {
+            DescribeIndexRequest describeIndexRequest = DescribeIndexRequest.newBuilder()
+                    .setCollectionName(requestParam.getCollectionName())
+                    .setIndexName(requestParam.getIndexName())
+                    .build();
+
+            DescribeIndexResponse descResp = blockingStub().describeIndex(describeIndexRequest);
+            if (descResp.getStatus().getErrorCode() != ErrorCode.Success || descResp.getIndexDescriptionsCount() == 0) {
+                logError("Index doesn't exist:\n{}", requestParam.getIndexName());
+                return R.failed(R.Status.IndexNotExist, "Index doesn't exist");
+            }
+
             DropIndexRequest dropIndexRequest = DropIndexRequest.newBuilder()
                     .setCollectionName(requestParam.getCollectionName())
                     .setIndexName(requestParam.getIndexName())
+                    .setFieldName(descResp.getIndexDescriptions(0).getFieldName())
                     .build();
 
             Status response = blockingStub().dropIndex(dropIndexRequest);

+ 139 - 143
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -341,128 +341,136 @@ class MilvusClientDockerTest {
         DescIndexResponseWrapper indexDesc = new DescIndexResponseWrapper(descIndexR.getData());
         System.out.println("Index description: " + indexDesc.toString());
 
-//        // load collection
-//        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
-//                .withCollectionName(randomCollectionName)
-//                .build());
-//        assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
-//
-//        // show collections
-//        R<ShowCollectionsResponse> showR = client.showCollections(ShowCollectionsParam.newBuilder()
-//                .addCollectionName(randomCollectionName)
-//                .build());
-//        assertEquals(R.Status.Success.getCode(), showR.getStatus().intValue());
-//        ShowCollResponseWrapper info = new ShowCollResponseWrapper(showR.getData());
-//        System.out.println("Collection info: " + info.toString());
-//
-//        // show partitions
-//        R<ShowPartitionsResponse> showPartR = client.showPartitions(ShowPartitionsParam.newBuilder()
-//                .withCollectionName(randomCollectionName)
-//                .addPartitionName("_default") // each collection has a '_default' partition
-//                .build());
-//        assertEquals(R.Status.Success.getCode(), showPartR.getStatus().intValue());
-//        ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
-//        System.out.println("Partition info: " + infoPart.toString());
-//
-//        // query vectors to verify
-//        List<Long> queryIDs = new ArrayList<>();
-//        List<Double> compareWeights = new ArrayList<>();
-//        int nq = 5;
-//        Random ran = new Random();
-//        int randomIndex = ran.nextInt(rowCount - nq);
-//        for (int i = randomIndex; i < randomIndex + nq; ++i) {
-//            queryIDs.add(ids.get(i));
-//            compareWeights.add(weights.get(i));
-//        }
-//        String expr = field1Name + " in " + queryIDs.toString();
-//        List<String> outputFields = Arrays.asList(field1Name, field2Name, field3Name, field4Name, field5Name);
-//        QueryParam queryParam = QueryParam.newBuilder()
-//                .withCollectionName(randomCollectionName)
-//                .withExpr(expr)
-//                .withOutFields(outputFields)
-//                .build();
-//
-//        R<QueryResults> queryR = client.query(queryParam);
-//        assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
-//
-//        // verify query result
-//        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
-//        for (String fieldName : outputFields) {
-//            FieldDataWrapper wrapper = queryResultsWrapper.getFieldWrapper(fieldName);
-//            System.out.println("Query data of " + fieldName + ", row count: " + wrapper.getRowCount());
-//            System.out.println(wrapper.getFieldData());
-//            assertEquals(nq, wrapper.getFieldData().size());
-//
-//            if (fieldName.compareTo(field1Name) == 0) {
-//                List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
-//                assertEquals(nq, out.size());
-//                for (Object o : out) {
-//                    long id = (Long) o;
-//                    assertTrue(queryIDs.contains(id));
-//                }
-//            }
-//        }
-//
-//        // Note: the query() return vectors are not in same sequence to the input
-//        // here we cannot compare vector one by one
-//        // the boolean also cannot be compared
-//        if (outputFields.contains(field2Name)) {
-//            assertTrue(queryResultsWrapper.getFieldWrapper(field2Name).isVectorField());
-//            List<?> out = queryResultsWrapper.getFieldWrapper(field2Name).getFieldData();
-//            assertEquals(nq, out.size());
-//        }
-//
-//        if (outputFields.contains(field3Name)) {
-//            List<?> out = queryResultsWrapper.getFieldWrapper(field3Name).getFieldData();
-//            assertEquals(nq, out.size());
-//        }
-//
-//        if (outputFields.contains(field4Name)) {
-//            List<?> out = queryResultsWrapper.getFieldWrapper(field4Name).getFieldData();
-//            assertEquals(nq, out.size());
-//            for (Object o : out) {
-//                double d = (Double) o;
-//                assertTrue(compareWeights.contains(d));
-//            }
-//        }
-//
-//
-//        // pick some vectors to search
-//        List<Long> targetVectorIDs = new ArrayList<>();
-//        List<List<Float>> targetVectors = new ArrayList<>();
-//        for (int i = randomIndex; i < randomIndex + nq; ++i) {
-//            targetVectorIDs.add(ids.get(i));
-//            targetVectors.add(vectors.get(i));
-//        }
-//
-//        int topK = 5;
-//        SearchParam searchParam = SearchParam.newBuilder()
-//                .withCollectionName(randomCollectionName)
-//                .withMetricType(MetricType.L2)
-//                .withTopK(topK)
-//                .withVectors(targetVectors)
-//                .withVectorFieldName(field2Name)
-//                .withParams("{\"nprobe\":8}")
-//                .addOutField(field4Name)
-//                .build();
-//
-//        R<SearchResults> searchR = client.search(searchParam);
-////        System.out.println(searchR);
-//        assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
-//
-//        // verify the search result
-//        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
-//        for (int i = 0; i < targetVectors.size(); ++i) {
-//            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
-//            System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
-//            System.out.println(scores);
-//            assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
-//        }
-//
-//        List<?> fieldData = results.getFieldData(field4Name, 0);
-//        assertEquals(topK, fieldData.size());
-//        fieldData = results.getFieldData(field4Name, nq - 1);
-//        assertEquals(topK, fieldData.size());
+        // load collection
+        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+        assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
+
+        // show collections
+        R<ShowCollectionsResponse> showR = client.showCollections(ShowCollectionsParam.newBuilder()
+                .addCollectionName(randomCollectionName)
+                .build());
+        assertEquals(R.Status.Success.getCode(), showR.getStatus().intValue());
+        ShowCollResponseWrapper info = new ShowCollResponseWrapper(showR.getData());
+        System.out.println("Collection info: " + info.toString());
+
+        // show partitions
+        R<ShowPartitionsResponse> showPartR = client.showPartitions(ShowPartitionsParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .addPartitionName("_default") // each collection has a '_default' partition
+                .build());
+        assertEquals(R.Status.Success.getCode(), showPartR.getStatus().intValue());
+        ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
+        System.out.println("Partition info: " + infoPart.toString());
+
+        // query vectors to verify
+        List<Long> queryIDs = new ArrayList<>();
+        List<Double> compareWeights = new ArrayList<>();
+        int nq = 5;
+        Random ran = new Random();
+        int randomIndex = ran.nextInt(rowCount - nq);
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            queryIDs.add(ids.get(i));
+            compareWeights.add(weights.get(i));
+        }
+        String expr = field1Name + " in " + queryIDs.toString();
+        List<String> outputFields = Arrays.asList(field1Name, field2Name, field3Name, field4Name, field5Name);
+        QueryParam queryParam = QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(expr)
+                .withOutFields(outputFields)
+                .build();
+
+        R<QueryResults> queryR = client.query(queryParam);
+        assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
+
+        // verify query result
+        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
+        for (String fieldName : outputFields) {
+            FieldDataWrapper wrapper = queryResultsWrapper.getFieldWrapper(fieldName);
+            System.out.println("Query data of " + fieldName + ", row count: " + wrapper.getRowCount());
+            System.out.println(wrapper.getFieldData());
+            assertEquals(nq, wrapper.getFieldData().size());
+
+            if (fieldName.compareTo(field1Name) == 0) {
+                List<?> out = queryResultsWrapper.getFieldWrapper(field1Name).getFieldData();
+                assertEquals(nq, out.size());
+                for (Object o : out) {
+                    long id = (Long) o;
+                    assertTrue(queryIDs.contains(id));
+                }
+            }
+        }
+
+        // Note: the query() return vectors are not in same sequence to the input
+        // here we cannot compare vector one by one
+        // the boolean also cannot be compared
+        if (outputFields.contains(field2Name)) {
+            assertTrue(queryResultsWrapper.getFieldWrapper(field2Name).isVectorField());
+            List<?> out = queryResultsWrapper.getFieldWrapper(field2Name).getFieldData();
+            assertEquals(nq, out.size());
+        }
+
+        if (outputFields.contains(field3Name)) {
+            List<?> out = queryResultsWrapper.getFieldWrapper(field3Name).getFieldData();
+            assertEquals(nq, out.size());
+        }
+
+        if (outputFields.contains(field4Name)) {
+            List<?> out = queryResultsWrapper.getFieldWrapper(field4Name).getFieldData();
+            assertEquals(nq, out.size());
+            for (Object o : out) {
+                double d = (Double) o;
+                assertTrue(compareWeights.contains(d));
+            }
+        }
+
+
+        // pick some vectors to search
+        List<Long> targetVectorIDs = new ArrayList<>();
+        List<List<Float>> targetVectors = new ArrayList<>();
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            targetVectorIDs.add(ids.get(i));
+            targetVectors.add(vectors.get(i));
+        }
+
+        int topK = 5;
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.L2)
+                .withTopK(topK)
+                .withVectors(targetVectors)
+                .withVectorFieldName(field2Name)
+                .withParams("{\"nprobe\":8}")
+                .addOutField(field4Name)
+                .build();
+
+        R<SearchResults> searchR = client.search(searchParam);
+//        System.out.println(searchR);
+        assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // verify the search result
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        for (int i = 0; i < targetVectors.size(); ++i) {
+            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
+            System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
+            System.out.println(scores);
+            assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
+        }
+
+        List<?> fieldData = results.getFieldData(field4Name, 0);
+        assertEquals(topK, fieldData.size());
+        fieldData = results.getFieldData(field4Name, nq - 1);
+        assertEquals(topK, fieldData.size());
+
+        // drop index
+        DropIndexParam dropIndexParam = DropIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withIndexName(indexParam.getIndexName())
+                .build();
+        R<RpcStatus> dropIndexR = client.dropIndex(dropIndexParam);
+        assertEquals(R.Status.Success.getCode(), dropIndexR.getStatus().intValue());
 
         // drop collection
         DropCollectionParam dropParam = DropCollectionParam.newBuilder()
@@ -779,27 +787,15 @@ class MilvusClientDockerTest {
 
         R<RpcStatus> createIndexR = client.createIndex(indexParam);
 
-        client.getIndexState(GetIndexStateParam.newBuilder().withCollectionName(collectionName)
-                .withIndexName(indexParam.getIndexName()).build());
-
-//        R<RpcStatus> kk = client.dropIndex(DropIndexParam.newBuilder()
-//                .withCollectionName(collectionName)
-//                .withFieldName(field2Name)
-//                .withIndexName("xxx")
-//                .build());
-//
-//        indexParam = CreateIndexParam.newBuilder()
-//                .withCollectionName(collectionName)
-//                .withFieldName(field2Name)
-//                .withIndexName("xxx")
-//                .withIndexType(IndexType.IVF_FLAT)
-//                .withMetricType(MetricType.IP)
-//                .withExtraParam("{\"nlist\":256}")
-//                .withSyncMode(Boolean.TRUE)
-//                .withSyncWaitingInterval(500L)
-//                .withSyncWaitingTimeout(30L)
-//                .build();
-//        createIndexR = client.createIndex(indexParam);
+        client.getIndexState(GetIndexStateParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withIndexName(indexParam.getIndexName())
+                .build());
+
+        R<RpcStatus> dropIndexR = client.dropIndex(DropIndexParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withIndexName(indexParam.getIndexName())
+                .build());
 
         client.dropCollection(DropCollectionParam.newBuilder().withCollectionName(collectionName).build());
     }

+ 26 - 1
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -1331,12 +1331,37 @@ class MilvusServiceClientTest {
 
     @Test
     void dropIndex() {
+        // start mock server
+        MockMilvusServer server = startServer();
+        MilvusServiceClient client = startClient();
+
         DropIndexParam param = DropIndexParam.newBuilder()
                 .withCollectionName("collection1")
                 .withIndexName("idx")
                 .build();
 
-        testFuncByName("dropIndex", param);
+        // test return ok with correct input
+        mockServerImpl.setDescribeIndexResponse(DescribeIndexResponse.newBuilder()
+                .addIndexDescriptions(IndexDescription.newBuilder()
+                        .setIndexName(param.getIndexName())
+                        .setFieldName("fff")
+                        .build())
+                .build());
+
+        R<RpcStatus> resp = client.dropIndex(param);
+        assertEquals(R.Status.Success.getCode(), resp.getStatus());
+
+        // stop mock server
+        server.stop();
+
+        // test return error without server
+        resp = client.dropIndex(param);
+        assertNotEquals(R.Status.Success.getCode(), resp.getStatus());
+
+        // test return error when client channel is shutdown
+        client.close();
+        resp = client.dropIndex(param);
+        assertEquals(R.Status.ClientNotConnected.getCode(), resp.getStatus());
     }
 
     @Test