Jelajahi Sumber

Add sync parameter for loadCollection/loadPartitions/createIndex (#1332)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 bulan lalu
induk
melakukan
093abe3b11

+ 27 - 73
sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java

@@ -98,12 +98,16 @@ public class CollectionService extends BaseService {
         CreateIndexReq createIndexReq = CreateIndexReq.builder()
                         .indexParams(Collections.singletonList(indexParam))
                         .collectionName(request.getCollectionName())
+                        .sync(false)
                         .build();
         indexService.createIndex(blockingStub, createIndexReq);
         //load collection, set async to true since no need to wait loading progress
         try {
             //TimeUnit.MILLISECONDS.sleep(1000);
-            loadCollection(blockingStub, LoadCollectionReq.builder().async(true).collectionName(request.getCollectionName()).build());
+            loadCollection(blockingStub, LoadCollectionReq.builder()
+                    .sync(false)
+                    .collectionName(request.getCollectionName())
+                    .build());
         } catch (Exception e) {
             throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection failed: " + e);
         }
@@ -157,11 +161,15 @@ public class CollectionService extends BaseService {
                 CreateIndexReq createIndexReq = CreateIndexReq.builder()
                         .indexParams(Collections.singletonList(indexParam))
                         .collectionName(request.getCollectionName())
+                        .sync(false)
                         .build();
                 indexService.createIndex(blockingStub, createIndexReq);
             }
             //load collection, set async to true since no need to wait loading progress
-            loadCollection(blockingStub, LoadCollectionReq.builder().async(true).collectionName(request.getCollectionName()).build());
+            loadCollection(blockingStub, LoadCollectionReq.builder()
+                    .sync(false)
+                    .collectionName(request.getCollectionName())
+                    .build());
         }
 
         return null;
@@ -187,10 +195,6 @@ public class CollectionService extends BaseService {
         Status status = blockingStub.dropCollection(dropCollectionRequest);
         rpcUtils.handleResponse(title, status);
 
-        if (request.getAsync()) {
-            WaitForDropCollection(blockingStub, request);
-        }
-
         return null;
     }
 
@@ -289,7 +293,7 @@ public class CollectionService extends BaseService {
                 .build();
         Status status = blockingStub.loadCollection(loadCollectionRequest);
         rpcUtils.handleResponse(title, status);
-        if (!request.getAsync()) {
+        if (request.getSync()) {
             WaitForLoadCollection(blockingStub, request.getCollectionName(), request.getTimeout());
         }
 
@@ -304,7 +308,7 @@ public class CollectionService extends BaseService {
                 .build();
         Status status = blockingStub.loadCollection(loadCollectionRequest);
         rpcUtils.handleResponse(title, status);
-        if (request.getAsync()) {
+        if (request.getSync()) {
             WaitForLoadCollection(blockingStub, request.getCollectionName(), request.getTimeout());
         }
 
@@ -318,9 +322,6 @@ public class CollectionService extends BaseService {
                 .build();
         Status status = blockingStub.releaseCollection(releaseCollectionRequest);
         rpcUtils.handleResponse(title, status);
-        if (request.getAsync()) {
-            waitForCollectionRelease(blockingStub, request);
-        }
 
         return null;
     }
@@ -407,75 +408,28 @@ public class CollectionService extends BaseService {
                 .build();
     }
 
-    public void waitForCollectionRelease(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, ReleaseCollectionReq request) {
-        boolean isLoaded = true;
-        long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
-
-        while (isLoaded) {
-            // Call the getLoadState method
-            isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder().collectionName(request.getCollectionName()).build());
-            if (isLoaded) {
-                // Check if timeout is exceeded
-                if (System.currentTimeMillis() - startTime > request.getTimeout()) {
-                    throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
-                }
-                // Wait for a certain period before checking again
-                try {
-                    Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
-                } catch (InterruptedException e) {
-                    Thread.currentThread().interrupt();
-                    System.out.println("Thread was interrupted, Failed to complete operation");
-                    return; // or handle interruption appropriately
-                }
-            }
-        }
-    }
-
     private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
                                        String collectionName, long timeoutMs) {
-        boolean isLoaded = false;
         long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
 
-        while (!isLoaded) {
+        while (true) {
             // Call the getLoadState method
-            isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder().collectionName(collectionName).build());
-            if (!isLoaded) {
-                // Check if timeout is exceeded
-                if (System.currentTimeMillis() - startTime > timeoutMs) {
-                    throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
-                }
-                // Wait for a certain period before checking again
-                try {
-                    Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
-                } catch (InterruptedException e) {
-                    Thread.currentThread().interrupt();
-                    System.out.println("Thread was interrupted, Failed to complete operation");
-                    return; // or handle interruption appropriately
-                }
+            boolean isLoaded = getLoadState(blockingStub, GetLoadStateReq.builder().collectionName(collectionName).build());
+            if (isLoaded) {
+                return;
             }
-        }
-    }
-
-    private void WaitForDropCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DropCollectionReq request) {
-        boolean hasCollection = true;
-        long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
 
-        while (hasCollection) {
-            // Call the getLoadState method
-            hasCollection = hasCollection(blockingStub, HasCollectionReq.builder().collectionName(request.getCollectionName()).build());
-            if (hasCollection) {
-                // Check if timeout is exceeded
-                if (System.currentTimeMillis() - startTime > request.getTimeout()) {
-                    throw new MilvusClientException(ErrorCode.SERVER_ERROR, "drop collection timeout");
-                }
-                // Wait for a certain period before checking again
-                try {
-                    Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
-                } catch (InterruptedException e) {
-                    Thread.currentThread().interrupt();
-                    System.out.println("Thread was interrupted, Failed to complete operation");
-                    return; // or handle interruption appropriately
-                }
+            // Check if timeout is exceeded
+            if (System.currentTimeMillis() - startTime > timeoutMs) {
+                throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection timeout");
+            }
+            // Wait for a certain period before checking again
+            try {
+                Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                System.out.println("Thread was interrupted, Failed to complete operation");
+                return; // or handle interruption appropriately
             }
         }
     }

+ 1 - 0
sdk-core/src/main/java/io/milvus/v2/service/collection/request/DropCollectionReq.java

@@ -27,6 +27,7 @@ import lombok.experimental.SuperBuilder;
 @SuperBuilder
 public class DropCollectionReq {
     private String collectionName;
+    @Deprecated
     @Builder.Default
     private Boolean async = Boolean.TRUE;
     @Builder.Default

+ 22 - 1
sdk-core/src/main/java/io/milvus/v2/service/collection/request/LoadCollectionReq.java

@@ -32,10 +32,13 @@ public class LoadCollectionReq {
     private String collectionName;
     @Builder.Default
     private Integer numReplicas = 1;
+    @Deprecated
     @Builder.Default
     private Boolean async = Boolean.FALSE;
     @Builder.Default
-    private Long timeout = 60000L;
+    private Boolean sync = Boolean.TRUE; // wait the collection to be fully loaded. "async" is deprecated, use "sync" instead
+    @Builder.Default
+    private Long timeout = 60000L; // timeout value for waiting the collection to be fully loaded
     @Builder.Default
     private Boolean refresh = Boolean.FALSE;
     @Builder.Default
@@ -44,4 +47,22 @@ public class LoadCollectionReq {
     private Boolean skipLoadDynamicField = Boolean.FALSE;
     @Builder.Default
     private List<String> resourceGroups = new ArrayList<>();
+
+    public static abstract class LoadCollectionReqBuilder<C extends LoadCollectionReq, B extends LoadCollectionReq.LoadCollectionReqBuilder<C, B>> {
+        public B async(Boolean async) {
+            this.async$value = async;
+            this.async$set = true;
+            this.sync$value = !async;
+            this.sync$set = true;
+            return self();
+        }
+
+        public B sync(Boolean sync) {
+            this.sync$value = sync;
+            this.sync$set = true;
+            this.async$value = !sync;
+            this.async$set = true;
+            return self();
+        }
+    }
 }

+ 21 - 1
sdk-core/src/main/java/io/milvus/v2/service/collection/request/RefreshLoadReq.java

@@ -30,5 +30,25 @@ public class RefreshLoadReq {
     @Builder.Default
     private Boolean async = Boolean.TRUE;
     @Builder.Default
-    private Long timeout = 60000L;
+    private Boolean sync = Boolean.TRUE; // wait the collection to be fully loaded. "async" is deprecated, use "sync" instead
+    @Builder.Default
+    private Long timeout = 60000L; // timeout value for waiting the collection to be fully loaded
+
+    public static abstract class RefreshLoadReqBuilder<C extends RefreshLoadReq, B extends RefreshLoadReq.RefreshLoadReqBuilder<C, B>> {
+        public B async(Boolean async) {
+            this.async$value = async;
+            this.async$set = true;
+            this.sync$value = !async;
+            this.sync$set = true;
+            return self();
+        }
+
+        public B sync(Boolean sync) {
+            this.sync$value = sync;
+            this.sync$set = true;
+            this.async$value = !sync;
+            this.async$set = true;
+            return self();
+        }
+    }
 }

+ 1 - 0
sdk-core/src/main/java/io/milvus/v2/service/collection/request/ReleaseCollectionReq.java

@@ -27,6 +27,7 @@ import lombok.experimental.SuperBuilder;
 @SuperBuilder
 public class ReleaseCollectionReq {
     private String collectionName;
+    @Deprecated
     @Builder.Default
     private Boolean async = Boolean.TRUE;
     @Builder.Default

+ 62 - 0
sdk-core/src/main/java/io/milvus/v2/service/index/IndexService.java

@@ -23,6 +23,7 @@ import com.google.gson.JsonObject;
 import io.milvus.grpc.*;
 import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
+import io.milvus.v2.common.IndexBuildState;
 import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.exception.ErrorCode;
 import io.milvus.v2.exception.MilvusClientException;
@@ -78,6 +79,10 @@ public class IndexService extends BaseService {
 
             Status status = blockingStub.createIndex(builder.build());
             rpcUtils.handleResponse(title, status);
+            if (request.getSync()) {
+                WaitForIndexComplete(blockingStub, request.getCollectionName(), indexParam.getFieldName(),
+                        indexParam.getIndexName(), request.getTimeout());
+            }
         }
 
         return null;
@@ -180,4 +185,61 @@ public class IndexService extends BaseService {
         });
         return indexNames;
     }
+
+    private void WaitForIndexComplete(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                                      String collectionName, String fieldName, String indexName, long timeoutMs) {
+        long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
+
+        // alloc a timestamp from the server, the DescribeIndex() will use this timestamp to check the segments
+        // which are generated before this timestamp.
+        AllocTimestampResponse allocTsResp = blockingStub.allocTimestamp(AllocTimestampRequest.newBuilder().build());
+        rpcUtils.handleResponse("AllocTimestampRequest", allocTsResp.getStatus());
+        long serverTs = allocTsResp.getTimestamp();
+
+        while (true) {
+            DescribeIndexResp response = describeIndex(blockingStub, DescribeIndexReq.builder()
+                    .collectionName(collectionName)
+                    .fieldName(fieldName)
+                    .indexName(indexName)
+                    .timestamp(serverTs)
+                    .build());
+            List<DescribeIndexResp.IndexDesc> indices = response.getIndexDescriptions();
+            DescribeIndexResp.IndexDesc desc = null;
+            if (indices.size() == 1) {
+                desc = indices.get(0);
+            } else {
+                for (DescribeIndexResp.IndexDesc index : indices) {
+                    if (fieldName.equals(index.getFieldName())) {
+                        desc = index;
+                        break;
+                    }
+                }
+            }
+
+            if (desc == null) {
+                String msg = String.format("Failed to describe the index '%s' of field '%s' from serv side", fieldName, indexName);
+                throw new MilvusClientException(ErrorCode.SERVER_ERROR, msg);
+            }
+
+            if (desc.getIndexState() == IndexBuildState.Finished) {
+                return;
+            } else if (desc.getIndexState() == IndexBuildState.Failed) {
+                String msg = "Index is failed, reason: " + desc.getIndexFailedReason();
+                throw new MilvusClientException(ErrorCode.SERVER_ERROR, msg);
+            }
+
+            // Check if timeout is exceeded
+            if (System.currentTimeMillis() - startTime > timeoutMs) {
+                throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Create index timeout");
+            }
+            // Wait for a certain period before checking again
+            try {
+                Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                System.out.println("Thread was interrupted, failed to complete operation");
+                return; // or handle interruption appropriately
+            }
+        }
+    }
 }

+ 5 - 0
sdk-core/src/main/java/io/milvus/v2/service/index/request/CreateIndexReq.java

@@ -20,6 +20,7 @@
 package io.milvus.v2.service.index.request;
 
 import io.milvus.v2.common.IndexParam;
+import lombok.Builder;
 import lombok.Data;
 import lombok.NonNull;
 import lombok.experimental.SuperBuilder;
@@ -33,4 +34,8 @@ public class CreateIndexReq {
     @NonNull
     private String collectionName;
     private List<IndexParam> indexParams;
+    @Builder.Default
+    private Boolean sync = Boolean.TRUE; // wait the index to complete
+    @Builder.Default
+    private Long timeout = 60000L; // timeout value for waiting the index to complete
 }

+ 3 - 0
sdk-core/src/main/java/io/milvus/v2/service/index/request/DescribeIndexReq.java

@@ -19,6 +19,7 @@
 
 package io.milvus.v2.service.index.request;
 
+import lombok.Builder;
 import lombok.Data;
 import lombok.NonNull;
 import lombok.experimental.SuperBuilder;
@@ -30,4 +31,6 @@ public class DescribeIndexReq {
     private String collectionName;
     private String fieldName;
     private String indexName;
+    @Builder.Default
+    private Long timestamp = 0L; // only check segments generated before this timestamp. all the segments will be checked if this value is zero.
 }

+ 35 - 0
sdk-core/src/main/java/io/milvus/v2/service/partition/PartitionService.java

@@ -20,6 +20,8 @@
 package io.milvus.v2.service.partition;
 
 import io.milvus.grpc.*;
+import io.milvus.v2.exception.ErrorCode;
+import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.BaseService;
 import io.milvus.v2.service.partition.request.*;
 import io.milvus.v2.service.partition.response.*;
@@ -107,6 +109,9 @@ public class PartitionService extends BaseService {
                 .build();
         Status status = blockingStub.loadPartitions(loadPartitionsRequest);
         rpcUtils.handleResponse(title, status);
+        if (request.getSync()) {
+            WaitForLoadPartitions(blockingStub, request.getCollectionName(), request.getPartitionNames(), request.getTimeout());
+        }
 
         return null;
     }
@@ -122,4 +127,34 @@ public class PartitionService extends BaseService {
 
         return null;
     }
+
+    private void WaitForLoadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                                       String collectionName, List<String> partitions, long timeoutMs) {
+        long startTime = System.currentTimeMillis(); // Capture start time/ Timeout in milliseconds (60 seconds)
+
+        while (true) {
+            GetLoadingProgressResponse response = blockingStub.getLoadingProgress(GetLoadingProgressRequest.newBuilder()
+                    .setCollectionName(collectionName)
+                    .addAllPartitionNames(partitions)
+                    .build());
+            String title = String.format("GetLoadingProgressRequest collectionName:%s", collectionName);
+            rpcUtils.handleResponse(title, response.getStatus());
+            if (response.getProgress() >= 100) {
+                return;
+            }
+
+            // Check if timeout is exceeded
+            if (System.currentTimeMillis() - startTime > timeoutMs) {
+                throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load partitions timeout");
+            }
+            // Wait for a certain period before checking again
+            try {
+                Thread.sleep(500); // Sleep for 0.5 second. Adjust this value as needed.
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                System.out.println("Thread was interrupted, failed to complete operation");
+                return; // or handle interruption appropriately
+            }
+        }
+    }
 }

+ 4 - 0
sdk-core/src/main/java/io/milvus/v2/service/partition/request/LoadPartitionsReq.java

@@ -35,6 +35,10 @@ public class LoadPartitionsReq {
     @Builder.Default
     private Integer numReplicas = 1;
     @Builder.Default
+    private Boolean sync = Boolean.TRUE; // wait the partitions to be fully loaded
+    @Builder.Default
+    private Long timeout = 60000L; // timeout value for waiting the partitions to be fully loaded
+    @Builder.Default
     private Boolean refresh = Boolean.FALSE;
     @Builder.Default
     private List<String> loadFields = new ArrayList<>();

+ 2 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/response/QueryResp.java

@@ -19,6 +19,7 @@
 
 package io.milvus.v2.service.vector.response;
 
+import lombok.Builder;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
 
@@ -29,6 +30,7 @@ import java.util.Map;
 @SuperBuilder
 public class QueryResp {
     private List<QueryResult> queryResults;
+    @Builder.Default
     private long sessionTs = 1L; // default eventually ts
 
     @Data

+ 2 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java

@@ -19,6 +19,7 @@
 
 package io.milvus.v2.service.vector.response;
 
+import lombok.Builder;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
 
@@ -29,6 +30,7 @@ import java.util.Map;
 @SuperBuilder
 public class SearchResp {
     private List<List<SearchResult>> searchResults;
+    @Builder.Default
     private long sessionTs = 1L; // default eventually ts
     private List<Float> recalls;
 

+ 15 - 2
sdk-core/src/test/java/io/milvus/v2/BaseTest.java

@@ -76,7 +76,7 @@ public class BaseTest {
                 .setCreatedUtcTimestamp(0)
                 .build();
 
-        IndexDescription index = IndexDescription.newBuilder()
+        IndexDescription index1 = IndexDescription.newBuilder()
                 .setIndexName("test")
                 .setFieldName("vector")
                 .addParams(KeyValuePair.newBuilder()
@@ -87,10 +87,21 @@ public class BaseTest {
                         .setKey("metric_type")
                         .setValue("L2")
                         .build())
+                .setState(IndexState.Finished)
+                .build();
+        IndexDescription index2 = IndexDescription.newBuilder()
+                .setIndexName("age")
+                .setFieldName("age")
+                .addParams(KeyValuePair.newBuilder()
+                        .setKey("index_type")
+                        .setValue("INVERTED")
+                        .build())
+                .setState(IndexState.Finished)
                 .build();
         DescribeIndexResponse describeIndexResponse = DescribeIndexResponse.newBuilder()
                 .setStatus(successStatus)
-                .addIndexDescriptions(index)
+                .addIndexDescriptions(index1)
+                .addIndexDescriptions(index2)
                 .build();
         when(blockingStub.listDatabases(any())).thenReturn(ListDatabasesResponse.newBuilder().setStatus(successStatus).addDbNames("default").build());
         // collection api
@@ -104,6 +115,7 @@ public class BaseTest {
         when(blockingStub.describeCollection(any())).thenReturn(describeCollectionResponse);
         when(blockingStub.renameCollection(any())).thenReturn(successStatus);
         when(blockingStub.getCollectionStatistics(any())).thenReturn(GetCollectionStatisticsResponse.newBuilder().addStats(KeyValuePair.newBuilder().setKey("row_count").setValue("10").build()).setStatus(successStatus).build());
+        when(blockingStub.getLoadingProgress(any())).thenReturn(GetLoadingProgressResponse.newBuilder().setStatus(successStatus).setProgress(100).build());
 
         // index api
         when(blockingStub.createIndex(any())).thenReturn(successStatus);
@@ -151,6 +163,7 @@ public class BaseTest {
         when(blockingStub.alterAlias(any())).thenReturn(successStatus);
         when(blockingStub.describeAlias(any())).thenReturn(DescribeAliasResponse.newBuilder().setStatus(successStatus).build());
         when(blockingStub.listAliases(any())).thenReturn(ListAliasesResponse.newBuilder().setStatus(successStatus).addAliases("test").build());
+        when(blockingStub.allocTimestamp(any())).thenReturn(AllocTimestampResponse.newBuilder().setStatus(successStatus).setTimestamp(1L).build());
     }
     @AfterEach
     public void tearDown() throws InterruptedException {

+ 19 - 11
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -299,21 +299,10 @@ class MilvusClientV2DockerTest {
                 .dimension(DIMENSION)
                 .build());
 
-        Map<String,Object> extraParams = new HashMap<>();
-        extraParams.put("M",16);
-        extraParams.put("efConstruction",64);
-        IndexParam indexParam = IndexParam.builder()
-                .fieldName(vectorFieldName)
-                .indexType(IndexParam.IndexType.HNSW)
-                .metricType(IndexParam.MetricType.COSINE)
-                .extraParams(extraParams)
-                .build();
-
         CreateCollectionReq requestCreate = CreateCollectionReq.builder()
                 .collectionName(randomCollectionName)
                 .description("dummy")
                 .collectionSchema(collectionSchema)
-                .indexParams(Collections.singletonList(indexParam))
                 .build();
         client.createCollection(requestCreate);
 
@@ -338,6 +327,25 @@ class MilvusClientV2DockerTest {
         // there is a segment is flushed by the flush() interface, there could be a compaction task created
         Assertions.assertTrue(compactResp.getCompactionID() == -1L || compactResp.getCompactionID() > 0L);
 
+        // create index
+        Map<String,Object> extraParams = new HashMap<>();
+        extraParams.put("M", 64);
+        extraParams.put("efConstruction", 200);
+        IndexParam indexParam = IndexParam.builder()
+                .fieldName(vectorFieldName)
+                .indexType(IndexParam.IndexType.HNSW)
+                .metricType(IndexParam.MetricType.COSINE)
+                .extraParams(extraParams)
+                .build();
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(randomCollectionName)
+                .indexParams(Collections.singletonList(indexParam))
+                .build());
+
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .build());
+
         // create partition, upsert one row to the partition
         String partitionName = "PPP";
         client.createPartition(CreatePartitionReq.builder()