Browse Source

QueryIterator/SearchIterator support retry (#1270)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 4 months ago
parent
commit
70c08026ca

+ 1 - 2
src/main/java/io/milvus/orm/iterator/QueryIterator.java

@@ -190,8 +190,7 @@ public class QueryIterator {
                 .build();
 
         QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam);
-        QueryResults response = blockingStub.query(queryRequest);
-
+        QueryResults response = rpcUtils.retry(()->blockingStub.query(queryRequest));
         String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());
 

+ 1 - 1
src/main/java/io/milvus/orm/iterator/SearchIterator.java

@@ -257,7 +257,7 @@ public class SearchIterator {
         fillVectorsByPlType(searchParamBuilder);
 
         SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParamBuilder.build());
-        SearchResults response = blockingStub.search(searchRequest);
+        SearchResults response = rpcUtils.retry(()->blockingStub.search(searchRequest));
 
         String title = String.format("SearchRequest collectionName:%s", searchIteratorParam.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());

+ 80 - 189
src/main/java/io/milvus/v2/client/MilvusClientV2.java

@@ -20,14 +20,10 @@
 package io.milvus.v2.client;
 
 import io.grpc.ManagedChannel;
-import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
 import io.milvus.grpc.*;
 import io.milvus.orm.iterator.QueryIterator;
 import io.milvus.orm.iterator.SearchIterator;
 
-import io.milvus.v2.exception.ErrorCode;
-import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.database.DatabaseService;
 import io.milvus.v2.service.database.request.*;
 import io.milvus.v2.service.database.response.*;
@@ -53,13 +49,13 @@ import io.milvus.v2.service.vector.VectorService;
 import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.response.*;
 import io.milvus.v2.utils.ClientUtils;
+import io.milvus.v2.utils.RpcUtils;
 import lombok.NonNull;
 import lombok.Setter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.List;
-import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
 
 public class MilvusClientV2 {
@@ -76,6 +72,7 @@ public class MilvusClientV2 {
     private final RBACService rbacService = new RBACService();
     private final ResourceGroupService rgroupService = new ResourceGroupService();
     private final UtilityService utilityService = new UtilityService();
+    private RpcUtils rpcUtils = new RpcUtils();
     private ConnectConfig connectConfig;
     private RetryConfig retryConfig = RetryConfig.builder().build();
 
@@ -159,113 +156,7 @@ public class MilvusClientV2 {
     }
 
     public void retryConfig(RetryConfig retryConfig) {
-        this.retryConfig = retryConfig;
-    }
-
-    private <T> T retry(Callable<T> callable) {
-        int maxRetryTimes = retryConfig.getMaxRetryTimes();
-        // no retry, direct call the method
-        if (maxRetryTimes <= 1) {
-            try {
-                return callable.call();
-            } catch (StatusRuntimeException e) {
-                throw new MilvusClientException(ErrorCode.RPC_ERROR, e); // rpc error
-            } catch (MilvusClientException e) {
-                throw e; // server error or client error
-            } catch (Exception e) {
-                throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e); // others error treated as client error
-            }
-        }
-
-        // method to check timeout
-        long begin = System.currentTimeMillis();
-        long maxRetryTimeoutMs = retryConfig.getMaxRetryTimeoutMs();
-        Callable<Boolean> timeoutChecker = ()->{
-            long current = System.currentTimeMillis();
-            long cost = (current - begin);
-            if (maxRetryTimeoutMs > 0 && cost >= maxRetryTimeoutMs) {
-                return Boolean.TRUE;
-            }
-            return Boolean.FALSE;
-        };
-
-        // retry within timeout
-        long retryIntervalMs = retryConfig.getInitialBackOffMs();
-        for (int k = 1; k <= maxRetryTimes; k++) {
-            try {
-                return callable.call();
-            } catch (StatusRuntimeException e) {
-                Status.Code code = e.getStatus().getCode();
-                if (code == Status.DEADLINE_EXCEEDED.getCode()
-                        || code == Status.PERMISSION_DENIED.getCode()
-                        || code == Status.UNAUTHENTICATED.getCode()
-                        || code == Status.INVALID_ARGUMENT.getCode()
-                        || code == Status.ALREADY_EXISTS.getCode()
-                        || code == Status.RESOURCE_EXHAUSTED.getCode()
-                        || code == Status.UNIMPLEMENTED.getCode()) {
-                    String msg = String.format("Encounter rpc error that cannot be retried, reason: %s", e);
-                    logger.error(msg);
-                    throw new MilvusClientException(ErrorCode.RPC_ERROR, msg); // throw rpc error
-                }
-
-                try {
-                    if (timeoutChecker.call() == Boolean.TRUE) {
-                        String msg = String.format("Retry timeout: %dms, maxRetry:%d, retries: %d, reason: %s",
-                                maxRetryTimeoutMs, maxRetryTimes, k, e.getMessage());
-                        logger.warn(msg);
-                        throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exit retry for timeout
-                    }
-                } catch (Exception ignored) {
-                }
-            } catch (MilvusClientException e) {
-                try {
-                    if (timeoutChecker.call() == Boolean.TRUE) {
-                        String msg = String.format("Retry timeout: %dms, maxRetry:%d, retries: %d, reason: %s",
-                                maxRetryTimeoutMs, maxRetryTimes, k, e.getMessage());
-                        logger.warn(msg);
-                        throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exit retry for timeout
-                    }
-                } catch (Exception ignored) {
-                }
-
-                // for server-side returned error, only retry for rate limit
-                // in new error codes of v2.3, rate limit error value is 8
-                if (retryConfig.isRetryOnRateLimit() &&
-                        (e.getLegacyServerCode() == io.milvus.grpc.ErrorCode.RateLimit.getNumber() ||
-                                e.getServerErrCode() == 8)) {
-                    // cannot be retried
-                } else {
-                    throw e; // exit retry, throw the error
-                }
-            } catch (Exception e) {
-                throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e); // others error treated as client error
-            }
-
-            try {
-                if (k >= maxRetryTimes) {
-                    // finish retry loop, return the response of the last retry
-                    String msg = String.format("Finish %d retry times, stop retry", maxRetryTimes);
-                    logger.warn(msg);
-                    throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exceed max time, exit retry
-                } else {
-                    // sleep for interval
-                    // print log, follow the pymilvus logic
-                    if (k > 3) {
-                        logger.warn(String.format("Retry(%d) with interval %dms", k, retryIntervalMs));
-                    }
-                    TimeUnit.MILLISECONDS.sleep(retryIntervalMs);
-                }
-
-                // reset the next interval value
-                retryIntervalMs = retryIntervalMs*retryConfig.getBackOffMultiplier();
-                if (retryIntervalMs > retryConfig.getMaxBackOffMs()) {
-                    retryIntervalMs = retryConfig.getMaxBackOffMs();
-                }
-            } catch (Exception ignored) {
-            }
-        }
-
-        return null;
+        rpcUtils.retryConfig(retryConfig);
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -294,21 +185,21 @@ public class MilvusClientV2 {
      * @param request create database request
      */
     public void createDatabase(CreateDatabaseReq request) {
-        retry(()-> databaseService.createDatabase(this.getRpcStub(), request));
+        rpcUtils.retry(()-> databaseService.createDatabase(this.getRpcStub(), request));
     }
     /**
      * Drops a database. Note that this method drops all data in the database.
      * @param request drop database request
      */
     public void dropDatabase(DropDatabaseReq request) {
-        retry(()-> databaseService.dropDatabase(this.getRpcStub(), request));
+        rpcUtils.retry(()-> databaseService.dropDatabase(this.getRpcStub(), request));
     }
     /**
      * List all databases.
      * @return List of String database names
      */
     public ListDatabasesResp listDatabases() {
-        return retry(()-> databaseService.listDatabases(this.getRpcStub()));
+        return rpcUtils.retry(()-> databaseService.listDatabases(this.getRpcStub()));
     }
     /**
      * Alter database with key value pair.
@@ -327,14 +218,14 @@ public class MilvusClientV2 {
      * @param request alter database properties request
      */
     public void alterDatabaseProperties(AlterDatabasePropertiesReq request) {
-        retry(()-> databaseService.alterDatabaseProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()-> databaseService.alterDatabaseProperties(this.getRpcStub(), request));
     }
     /**
      * drop a database's properties.
      * @param request drop database properties request
      */
     public void dropDatabaseProperties(DropDatabasePropertiesReq request) {
-        retry(()-> databaseService.dropDatabaseProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()-> databaseService.dropDatabaseProperties(this.getRpcStub(), request));
     }
     /**
      * Show details of a database, such as replica number and resource groups. (Available from Milvus v2.4.4)
@@ -343,7 +234,7 @@ public class MilvusClientV2 {
      * @return DescribeDatabaseResp
      */
     public DescribeDatabaseResp describeDatabase(DescribeDatabaseReq request) {
-        return retry(()-> databaseService.describeDatabase(this.getRpcStub(), request));
+        return rpcUtils.retry(()-> databaseService.describeDatabase(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -354,7 +245,7 @@ public class MilvusClientV2 {
      * @param request create collection request
      */
     public void createCollection(CreateCollectionReq request) {
-        retry(()-> collectionService.createCollection(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.createCollection(this.getRpcStub(), request));
     }
     /**
      * Creates a collection schema.
@@ -369,7 +260,7 @@ public class MilvusClientV2 {
      * @return List of String collection names
      */
     public ListCollectionsResp listCollections() {
-        return retry(()-> collectionService.listCollections(this.getRpcStub()));
+        return rpcUtils.retry(()-> collectionService.listCollections(this.getRpcStub()));
     }
     /**
      * Drops a collection in Milvus.
@@ -377,7 +268,7 @@ public class MilvusClientV2 {
      * @param request drop collection request
      */
     public void dropCollection(DropCollectionReq request) {
-        retry(()-> collectionService.dropCollection(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.dropCollection(this.getRpcStub(), request));
     }
     /**
      * Alter a collection in Milvus.
@@ -399,7 +290,7 @@ public class MilvusClientV2 {
      * @param request alter collection properties request
      */
     public void alterCollectionProperties(AlterCollectionPropertiesReq request) {
-        retry(()-> collectionService.alterCollectionProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.alterCollectionProperties(this.getRpcStub(), request));
     }
     /**
      * Alter a field's properties.
@@ -407,14 +298,14 @@ public class MilvusClientV2 {
      * @param request alter field properties request
      */
     public void alterCollectionField(AlterCollectionFieldReq request) {
-        retry(()-> collectionService.alterCollectionField(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.alterCollectionField(this.getRpcStub(), request));
     }
     /**
      * drop a collection's properties.
      * @param request drop collection properties request
      */
     public void dropCollectionProperties(DropCollectionPropertiesReq request) {
-        retry(()-> collectionService.dropCollectionProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.dropCollectionProperties(this.getRpcStub(), request));
     }
     /**
      * Checks whether a collection exists in Milvus.
@@ -423,7 +314,7 @@ public class MilvusClientV2 {
      * @return Boolean
      */
     public Boolean hasCollection(HasCollectionReq request) {
-        return retry(()-> collectionService.hasCollection(this.getRpcStub(), request));
+        return rpcUtils.retry(()-> collectionService.hasCollection(this.getRpcStub(), request));
     }
     /**
      * Gets the collection info in Milvus.
@@ -432,7 +323,7 @@ public class MilvusClientV2 {
      * @return DescribeCollectionResp
      */
     public DescribeCollectionResp describeCollection(DescribeCollectionReq request) {
-        return retry(()-> collectionService.describeCollection(this.getRpcStub(), request));
+        return rpcUtils.retry(()-> collectionService.describeCollection(this.getRpcStub(), request));
     }
     /**
      * get collection stats for a collection in Milvus.
@@ -441,7 +332,7 @@ public class MilvusClientV2 {
      * @return GetCollectionStatsResp
      */
     public GetCollectionStatsResp getCollectionStats(GetCollectionStatsReq request) {
-        return retry(()-> collectionService.getCollectionStats(this.getRpcStub(), request));
+        return rpcUtils.retry(()-> collectionService.getCollectionStats(this.getRpcStub(), request));
     }
     /**
      * Renames a collection in Milvus.
@@ -449,7 +340,7 @@ public class MilvusClientV2 {
      * @param request rename collection request
      */
     public void renameCollection(RenameCollectionReq request) {
-        retry(()-> collectionService.renameCollection(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.renameCollection(this.getRpcStub(), request));
     }
     /**
      * Loads a collection into memory in Milvus.
@@ -457,7 +348,7 @@ public class MilvusClientV2 {
      * @param request load collection request
      */
     public void loadCollection(LoadCollectionReq request) {
-        retry(()-> collectionService.loadCollection(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.loadCollection(this.getRpcStub(), request));
     }
     /**
      * Refresh loads a collection. Mainly used when there are new segments generated by bulkinsert request.
@@ -467,7 +358,7 @@ public class MilvusClientV2 {
      * @param request refresh load collection request
      */
     public void refreshLoad(RefreshLoadReq request) {
-        retry(()-> collectionService.refreshLoad(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.refreshLoad(this.getRpcStub(), request));
     }
     /**
      * Releases a collection from memory in Milvus.
@@ -475,7 +366,7 @@ public class MilvusClientV2 {
      * @param request release collection request
      */
     public void releaseCollection(ReleaseCollectionReq request) {
-        retry(()-> collectionService.releaseCollection(this.getRpcStub(), request));
+        rpcUtils.retry(()-> collectionService.releaseCollection(this.getRpcStub(), request));
     }
     /**
      * Checks whether a collection is loaded in Milvus.
@@ -484,7 +375,7 @@ public class MilvusClientV2 {
      * @return Boolean
      */
     public Boolean getLoadState(GetLoadStateReq request) {
-        return retry(()->collectionService.getLoadState(this.getRpcStub(), request));
+        return rpcUtils.retry(()->collectionService.getLoadState(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -496,7 +387,7 @@ public class MilvusClientV2 {
      * @param request create index request
      */
     public void createIndex(CreateIndexReq request) {
-        retry(()->indexService.createIndex(this.getRpcStub(), request));
+        rpcUtils.retry(()->indexService.createIndex(this.getRpcStub(), request));
     }
     /**
      * Drops an index for a specified field in a collection in Milvus.
@@ -504,7 +395,7 @@ public class MilvusClientV2 {
      * @param request drop index request
      */
     public void dropIndex(DropIndexReq request) {
-        retry(()->indexService.dropIndex(this.getRpcStub(), request));
+        rpcUtils.retry(()->indexService.dropIndex(this.getRpcStub(), request));
     }
     /**
      * Alter an index in Milvus.
@@ -527,14 +418,14 @@ public class MilvusClientV2 {
      * @param request alter index request
      */
     public void alterIndexProperties(AlterIndexPropertiesReq request) {
-        retry(()->indexService.alterIndexProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()->indexService.alterIndexProperties(this.getRpcStub(), request));
     }
     /**
      * drop an index's properties.
      * @param request drop index properties request
      */
     public void dropIndexProperties(DropIndexPropertiesReq request) {
-        retry(()-> indexService.dropIndexProperties(this.getRpcStub(), request));
+        rpcUtils.retry(()-> indexService.dropIndexProperties(this.getRpcStub(), request));
     }
     /**
      * Checks whether an index exists for a specified field in a collection in Milvus.
@@ -543,7 +434,7 @@ public class MilvusClientV2 {
      * @return DescribeIndexResp
      */
     public DescribeIndexResp describeIndex(DescribeIndexReq request) {
-        return retry(()->indexService.describeIndex(this.getRpcStub(), request));
+        return rpcUtils.retry(()->indexService.describeIndex(this.getRpcStub(), request));
     }
     /**
      * Lists all indexes in a collection in Milvus.
@@ -552,7 +443,7 @@ public class MilvusClientV2 {
      * @return List of String names of the indexes
      */
     public List<String> listIndexes(ListIndexesReq request) {
-        return retry(()->indexService.listIndexes(this.getRpcStub(), request));
+        return rpcUtils.retry(()->indexService.listIndexes(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -565,7 +456,7 @@ public class MilvusClientV2 {
      * @return InsertResp
      */
     public InsertResp insert(InsertReq request) {
-        return retry(()->vectorService.insert(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.insert(this.getRpcStub(), request));
     }
     /**
      * Upsert vectors into a collection in Milvus.
@@ -574,7 +465,7 @@ public class MilvusClientV2 {
      * @return UpsertResp
      */
     public UpsertResp upsert(UpsertReq request) {
-        return retry(()->vectorService.upsert(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.upsert(this.getRpcStub(), request));
     }
     /**
      * Deletes vectors in a collection in Milvus.
@@ -583,7 +474,7 @@ public class MilvusClientV2 {
      * @return DeleteResp
      */
     public DeleteResp delete(DeleteReq request) {
-        return retry(()->vectorService.delete(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.delete(this.getRpcStub(), request));
     }
     /**
      * Gets vectors in a collection in Milvus.
@@ -592,7 +483,7 @@ public class MilvusClientV2 {
      * @return GetResp
      */
     public GetResp get(GetReq request) {
-        return retry(()->vectorService.get(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.get(this.getRpcStub(), request));
     }
 
     /**
@@ -602,7 +493,7 @@ public class MilvusClientV2 {
      * @return QueryResp
      */
     public QueryResp query(QueryReq request) {
-        return retry(()->vectorService.query(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.query(this.getRpcStub(), request));
     }
     /**
      * Searches vectors in a collection in Milvus.
@@ -611,7 +502,7 @@ public class MilvusClientV2 {
      * @return SearchResp
      */
     public SearchResp search(SearchReq request) {
-        return retry(()->vectorService.search(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.search(this.getRpcStub(), request));
     }
     /**
      * Conducts multi vector similarity search with a ranker for rearrangement.
@@ -620,7 +511,7 @@ public class MilvusClientV2 {
      * @return SearchResp
      */
     public SearchResp hybridSearch(HybridSearchReq request) {
-        return retry(()->vectorService.hybridSearch(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.hybridSearch(this.getRpcStub(), request));
     }
 
     /**
@@ -631,7 +522,7 @@ public class MilvusClientV2 {
      * @return QueryIterator
      */
     public QueryIterator queryIterator(QueryIteratorReq request) {
-        return retry(()->vectorService.queryIterator(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.queryIterator(this.getRpcStub(), request));
     }
 
     /**
@@ -641,7 +532,7 @@ public class MilvusClientV2 {
      * @return SearchIterator
      */
     public SearchIterator searchIterator(SearchIteratorReq request) {
-        return retry(()->vectorService.searchIterator(this.getRpcStub(), request));
+        return rpcUtils.retry(()->vectorService.searchIterator(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -653,7 +544,7 @@ public class MilvusClientV2 {
      * @param request create partition request
      */
     public void createPartition(CreatePartitionReq request) {
-        retry(()->partitionService.createPartition(this.getRpcStub(), request));
+        rpcUtils.retry(()->partitionService.createPartition(this.getRpcStub(), request));
     }
 
     /**
@@ -662,7 +553,7 @@ public class MilvusClientV2 {
      * @param request drop partition request
      */
     public void dropPartition(DropPartitionReq request) {
-        retry(()->partitionService.dropPartition(this.getRpcStub(), request));
+        rpcUtils.retry(()->partitionService.dropPartition(this.getRpcStub(), request));
     }
 
     /**
@@ -672,7 +563,7 @@ public class MilvusClientV2 {
      * @return Boolean
      */
     public Boolean hasPartition(HasPartitionReq request) {
-        return retry(()->partitionService.hasPartition(this.getRpcStub(), request));
+        return rpcUtils.retry(()->partitionService.hasPartition(this.getRpcStub(), request));
     }
 
     /**
@@ -682,7 +573,7 @@ public class MilvusClientV2 {
      * @return List of String partition names
      */
     public List<String> listPartitions(ListPartitionsReq request) {
-        return retry(()->partitionService.listPartitions(this.getRpcStub(), request));
+        return rpcUtils.retry(()->partitionService.listPartitions(this.getRpcStub(), request));
     }
 
     /**
@@ -692,7 +583,7 @@ public class MilvusClientV2 {
      * @return GetPartitionStatsResp
      */
     public GetPartitionStatsResp getPartitionStats(GetPartitionStatsReq request) {
-        return retry(()-> partitionService.getPartitionStats(this.getRpcStub(), request));
+        return rpcUtils.retry(()-> partitionService.getPartitionStats(this.getRpcStub(), request));
     }
 
     /**
@@ -701,7 +592,7 @@ public class MilvusClientV2 {
      * @param request load partitions request
      */
     public void loadPartitions(LoadPartitionsReq request) {
-        retry(()->partitionService.loadPartitions(this.getRpcStub(), request));
+        rpcUtils.retry(()->partitionService.loadPartitions(this.getRpcStub(), request));
     }
     /**
      * Releases partitions in a collection in Milvus.
@@ -709,7 +600,7 @@ public class MilvusClientV2 {
      * @param request release partitions request
      */
     public void releasePartitions(ReleasePartitionsReq request) {
-        retry(()->partitionService.releasePartitions(this.getRpcStub(), request));
+        rpcUtils.retry(()->partitionService.releasePartitions(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -721,7 +612,7 @@ public class MilvusClientV2 {
      * @return List of String usernames
      */
     public List<String> listUsers() {
-        return retry(()->rbacService.listUsers(this.getRpcStub()));
+        return rpcUtils.retry(()->rbacService.listUsers(this.getRpcStub()));
     }
     /**
      * describe user
@@ -730,7 +621,7 @@ public class MilvusClientV2 {
      * @return DescribeUserResp
      */
     public DescribeUserResp describeUser(DescribeUserReq request) {
-        return retry(()->rbacService.describeUser(this.getRpcStub(), request));
+        return rpcUtils.retry(()->rbacService.describeUser(this.getRpcStub(), request));
     }
     /**
      * create user
@@ -738,7 +629,7 @@ public class MilvusClientV2 {
      * @param request create user request
      */
     public void createUser(CreateUserReq request) {
-        retry(()->rbacService.createUser(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.createUser(this.getRpcStub(), request));
     }
     /**
      * change password
@@ -746,7 +637,7 @@ public class MilvusClientV2 {
      * @param request change password request
      */
     public void updatePassword(UpdatePasswordReq request) {
-        retry(()->rbacService.updatePassword(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.updatePassword(this.getRpcStub(), request));
     }
     /**
      * drop user
@@ -754,7 +645,7 @@ public class MilvusClientV2 {
      * @param request drop user request
      */
     public void dropUser(DropUserReq request) {
-        retry(()->rbacService.dropUser(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.dropUser(this.getRpcStub(), request));
     }
     // role operations
     /**
@@ -763,7 +654,7 @@ public class MilvusClientV2 {
      * @return List of String role names
      */
     public List<String> listRoles() {
-        return retry(()->rbacService.listRoles(this.getRpcStub()));
+        return rpcUtils.retry(()->rbacService.listRoles(this.getRpcStub()));
     }
     /**
      * describe role
@@ -772,7 +663,7 @@ public class MilvusClientV2 {
      * @return DescribeRoleResp
      */
     public DescribeRoleResp describeRole(DescribeRoleReq request) {
-        return retry(()->rbacService.describeRole(this.getRpcStub(), request));
+        return rpcUtils.retry(()->rbacService.describeRole(this.getRpcStub(), request));
     }
     /**
      * create role
@@ -780,7 +671,7 @@ public class MilvusClientV2 {
      * @param request create role request
      */
     public void createRole(CreateRoleReq request) {
-        retry(()->rbacService.createRole(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.createRole(this.getRpcStub(), request));
     }
     /**
      * drop role
@@ -788,7 +679,7 @@ public class MilvusClientV2 {
      * @param request drop role request
      */
     public void dropRole(DropRoleReq request) {
-        retry(()->rbacService.dropRole(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.dropRole(this.getRpcStub(), request));
     }
     /**
      * grant privilege
@@ -796,7 +687,7 @@ public class MilvusClientV2 {
      * @param request grant privilege request
      */
     public void grantPrivilege(GrantPrivilegeReq request) {
-        retry(()->rbacService.grantPrivilege(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.grantPrivilege(this.getRpcStub(), request));
     }
     /**
      * revoke privilege
@@ -804,7 +695,7 @@ public class MilvusClientV2 {
      * @param request revoke privilege request
      */
     public void revokePrivilege(RevokePrivilegeReq request) {
-        retry(()->rbacService.revokePrivilege(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.revokePrivilege(this.getRpcStub(), request));
     }
     /**
      * grant role
@@ -812,7 +703,7 @@ public class MilvusClientV2 {
      * @param request grant role request
      */
     public void grantRole(GrantRoleReq request) {
-        retry(()->rbacService.grantRole(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.grantRole(this.getRpcStub(), request));
     }
     /**
      * revoke role
@@ -820,35 +711,35 @@ public class MilvusClientV2 {
      * @param request revoke role request
      */
     public void revokeRole(RevokeRoleReq request) {
-        retry(()->rbacService.revokeRole(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.revokeRole(this.getRpcStub(), request));
     }
 
     public void createPrivilegeGroup(CreatePrivilegeGroupReq request) {
-        retry(()->rbacService.createPrivilegeGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.createPrivilegeGroup(this.getRpcStub(), request));
     }
 
     public void dropPrivilegeGroup(DropPrivilegeGroupReq request) {
-        retry(()->rbacService.dropPrivilegeGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.dropPrivilegeGroup(this.getRpcStub(), request));
     }
 
     public ListPrivilegeGroupsResp listPrivilegeGroups(ListPrivilegeGroupsReq request) {
-        return retry(()->rbacService.listPrivilegeGroups(this.getRpcStub(), request));
+        return rpcUtils.retry(()->rbacService.listPrivilegeGroups(this.getRpcStub(), request));
     }
 
     public void addPrivilegesToGroup(AddPrivilegesToGroupReq request) {
-        retry(()->rbacService.addPrivilegesToGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.addPrivilegesToGroup(this.getRpcStub(), request));
     }
 
     public void removePrivilegesFromGroup(RemovePrivilegesFromGroupReq request) {
-        retry(()->rbacService.removePrivilegesFromGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.removePrivilegesFromGroup(this.getRpcStub(), request));
     }
 
     public void grantPrivilegeV2(GrantPrivilegeReqV2 request) {
-        retry(()->rbacService.grantPrivilegeV2(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.grantPrivilegeV2(this.getRpcStub(), request));
     }
 
     public void revokePrivilegeV2(RevokePrivilegeReqV2 request) {
-        retry(()->rbacService.revokePrivilegeV2(this.getRpcStub(), request));
+        rpcUtils.retry(()->rbacService.revokePrivilegeV2(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -860,7 +751,7 @@ public class MilvusClientV2 {
      * @param request {@link CreateResourceGroupReq}
      */
     public void createResourceGroup(CreateResourceGroupReq request){
-        retry(()->rgroupService.createResourceGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rgroupService.createResourceGroup(this.getRpcStub(), request));
     }
 
     /**
@@ -869,7 +760,7 @@ public class MilvusClientV2 {
      * @param request {@link UpdateResourceGroupsReq}
      */
     public void updateResourceGroups(UpdateResourceGroupsReq request) {
         // Delegates to rgroupService.updateResourceGroups on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->rgroupService.updateResourceGroups(this.getRpcStub(), request));
+        rpcUtils.retry(()->rgroupService.updateResourceGroups(this.getRpcStub(), request));
     }
 
     /**
@@ -878,7 +769,7 @@ public class MilvusClientV2 {
      * @param request {@link DropResourceGroupReq}
      */
     public void dropResourceGroup(DropResourceGroupReq request) {
         // Delegates to rgroupService.dropResourceGroup on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->rgroupService.dropResourceGroup(this.getRpcStub(), request));
+        rpcUtils.retry(()->rgroupService.dropResourceGroup(this.getRpcStub(), request));
     }
 
     /**
@@ -888,7 +779,7 @@ public class MilvusClientV2 {
      * @return ListResourceGroupsResp
      */
     // NOTE(review): declared without an access modifier (package-private), unlike the public
     // resource-group methods around it - confirm this is intended.
     ListResourceGroupsResp listResourceGroups(ListResourceGroupsReq request) {
         // Delegates to rgroupService.listResourceGroups on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->rgroupService.listResourceGroups(this.getRpcStub(), request));
+        return rpcUtils.retry(()->rgroupService.listResourceGroups(this.getRpcStub(), request));
     }
 
     /**
@@ -898,7 +789,7 @@ public class MilvusClientV2 {
      * @return DescribeResourceGroupResp
      */
     // NOTE(review): declared without an access modifier (package-private), unlike the public
     // resource-group methods around it - confirm this is intended.
     DescribeResourceGroupResp describeResourceGroup(DescribeResourceGroupReq request) {
         // Delegates to rgroupService.describeResourceGroup on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->rgroupService.describeResourceGroup(this.getRpcStub(), request));
+        return rpcUtils.retry(()->rgroupService.describeResourceGroup(this.getRpcStub(), request));
     }
 
     /**
@@ -907,7 +798,7 @@ public class MilvusClientV2 {
      * @param request {@link TransferReplicaReq}
      */
     public void transferReplica(TransferReplicaReq request) {
         // Delegates to rgroupService.transferReplica on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->rgroupService.transferReplica(this.getRpcStub(), request));
+        rpcUtils.retry(()->rgroupService.transferReplica(this.getRpcStub(), request));
     }
 
     /////////////////////////////////////////////////////////////////////////////////////////////
@@ -919,7 +810,7 @@ public class MilvusClientV2 {
      * @param request create alias request
      */
     public void createAlias(CreateAliasReq request) {
         // Delegates to utilityService.createAlias on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->utilityService.createAlias(this.getRpcStub(), request));
+        rpcUtils.retry(()->utilityService.createAlias(this.getRpcStub(), request));
     }
     /**
      * drop aliases
@@ -927,7 +818,7 @@ public class MilvusClientV2 {
      * @param request drop alias request
      */
     public void dropAlias(DropAliasReq request) {
         // Delegates to utilityService.dropAlias on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->utilityService.dropAlias(this.getRpcStub(), request));
+        rpcUtils.retry(()->utilityService.dropAlias(this.getRpcStub(), request));
     }
     /**
      * alter aliases
@@ -935,7 +826,7 @@ public class MilvusClientV2 {
      * @param request alter alias request
      */
     public void alterAlias(AlterAliasReq request) {
         // Delegates to utilityService.alterAlias on the current RPC stub, executed through rpcUtils.retry.
-        retry(()->utilityService.alterAlias(this.getRpcStub(), request));
+        rpcUtils.retry(()->utilityService.alterAlias(this.getRpcStub(), request));
     }
     /**
      * list aliases
@@ -944,7 +835,7 @@ public class MilvusClientV2 {
      * @return List of String alias names
      */
     public ListAliasResp listAliases(ListAliasesReq request) {
         // Delegates to utilityService.listAliases on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->utilityService.listAliases(this.getRpcStub(), request));
+        return rpcUtils.retry(()->utilityService.listAliases(this.getRpcStub(), request));
     }
     /**
      * describe aliases
@@ -953,7 +844,7 @@ public class MilvusClientV2 {
      * @return DescribeAliasResp
      */
     public DescribeAliasResp describeAlias(DescribeAliasReq request) {
         // Delegates to utilityService.describeAlias on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->utilityService.describeAlias(this.getRpcStub(), request));
+        return rpcUtils.retry(()->utilityService.describeAlias(this.getRpcStub(), request));
     }
 
     /**
@@ -962,7 +853,7 @@ public class MilvusClientV2 {
      * @param request flush request
      */
     public void flush(FlushReq request) {
-        FlushResp response = retry(()->utilityService.flush(this.getRpcStub(), request));
+        FlushResp response = rpcUtils.retry(()->utilityService.flush(this.getRpcStub(), request));
 
         // The BlockingStub.flush() api returns immediately after the datanode set all growing segments to be "sealed".
         // The flush state becomes "Completed" after the datanode uploading them to S3 asynchronously.
@@ -982,7 +873,7 @@ public class MilvusClientV2 {
      * @return CompactResp
      */
     public CompactResp compact(CompactReq request) {
         // Delegates to utilityService.compact on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->utilityService.compact(this.getRpcStub(), request));
+        return rpcUtils.retry(()->utilityService.compact(this.getRpcStub(), request));
     }
 
     /**
@@ -992,7 +883,7 @@ public class MilvusClientV2 {
      * @return GetCompactStateResp
      */
     public GetCompactionStateResp getCompactionState(GetCompactionStateReq request) {
         // Delegates to utilityService.getCompactionState on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->utilityService.getCompactionState(this.getRpcStub(), request));
+        return rpcUtils.retry(()->utilityService.getCompactionState(this.getRpcStub(), request));
     }
 
     /**
@@ -1001,7 +892,7 @@ public class MilvusClientV2 {
      * @return String
      */
     public String getServerVersion() {
         // Delegates to clientUtils.getServerVersion on the current RPC stub, executed through rpcUtils.retry.
-        return retry(()->clientUtils.getServerVersion(this.getRpcStub()));
+        return rpcUtils.retry(()->clientUtils.getServerVersion(this.getRpcStub()));
     }
 
     /**

+ 116 - 0
src/main/java/io/milvus/v2/utils/RpcUtils.java

@@ -19,15 +19,25 @@
 
 package io.milvus.v2.utils;
 
+import io.grpc.StatusRuntimeException;
 import io.milvus.grpc.Status;
+import io.milvus.v2.client.RetryConfig;
 import io.milvus.v2.exception.ErrorCode;
 import io.milvus.v2.exception.MilvusClientException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.TimeUnit;
+
 public class RpcUtils {
 
     protected static final Logger logger = LoggerFactory.getLogger(RpcUtils.class);
+    private RetryConfig retryConfig = RetryConfig.builder().build();
+
+    /**
+     * Replaces the retry settings consulted by {@link #retry(Callable)}.
+     *
+     * @param retryConfig the new retry settings; must not be null
+     * @throws IllegalArgumentException if {@code retryConfig} is null
+     */
+    public void retryConfig(RetryConfig retryConfig) {
+        if (retryConfig == null) {
+            // reject here; otherwise the null only surfaces as an NPE inside a later retry() call
+            throw new IllegalArgumentException("retryConfig cannot be null");
+        }
+        this.retryConfig = retryConfig;
+    }
 
     public void handleResponse(String requestInfo, Status status) {
         // the server made a change for error code:
@@ -55,4 +65,110 @@ public class RpcUtils {
 
         logger.debug("{} successfully!", requestInfo);
     }
+
+    public <T> T retry(Callable<T> callable) {
+        int maxRetryTimes = retryConfig.getMaxRetryTimes();
+        // no retry, direct call the method
+        if (maxRetryTimes <= 1) {
+            try {
+                return callable.call();
+            } catch (StatusRuntimeException e) {
+                throw new MilvusClientException(ErrorCode.RPC_ERROR, e); // rpc error
+            } catch (MilvusClientException e) {
+                throw e; // server error or client error
+            } catch (Exception e) {
+                throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e); // others error treated as client error
+            }
+        }
+
+        // method to check timeout
+        long begin = System.currentTimeMillis();
+        long maxRetryTimeoutMs = retryConfig.getMaxRetryTimeoutMs();
+        Callable<Boolean> timeoutChecker = ()->{
+            long current = System.currentTimeMillis();
+            long cost = (current - begin);
+            if (maxRetryTimeoutMs > 0 && cost >= maxRetryTimeoutMs) {
+                return Boolean.TRUE;
+            }
+            return Boolean.FALSE;
+        };
+
+        // retry within timeout
+        long retryIntervalMs = retryConfig.getInitialBackOffMs();
+        for (int k = 1; k <= maxRetryTimes; k++) {
+            try {
+                return callable.call();
+            } catch (StatusRuntimeException e) {
+                io.grpc.Status.Code code = e.getStatus().getCode();
+                if (code == io.grpc.Status.DEADLINE_EXCEEDED.getCode()
+                        || code == io.grpc.Status.PERMISSION_DENIED.getCode()
+                        || code == io.grpc.Status.UNAUTHENTICATED.getCode()
+                        || code == io.grpc.Status.INVALID_ARGUMENT.getCode()
+                        || code == io.grpc.Status.ALREADY_EXISTS.getCode()
+                        || code == io.grpc.Status.RESOURCE_EXHAUSTED.getCode()
+                        || code == io.grpc.Status.UNIMPLEMENTED.getCode()) {
+                    String msg = String.format("Encounter rpc error that cannot be retried, reason: %s", e);
+                    logger.error(msg);
+                    throw new MilvusClientException(ErrorCode.RPC_ERROR, msg); // throw rpc error
+                }
+
+                try {
+                    if (timeoutChecker.call() == Boolean.TRUE) {
+                        String msg = String.format("Retry timeout: %dms, maxRetry:%d, retries: %d, reason: %s",
+                                maxRetryTimeoutMs, maxRetryTimes, k, e.getMessage());
+                        logger.warn(msg);
+                        throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exit retry for timeout
+                    }
+                } catch (Exception ignored) {
+                }
+            } catch (MilvusClientException e) {
+                try {
+                    if (timeoutChecker.call() == Boolean.TRUE) {
+                        String msg = String.format("Retry timeout: %dms, maxRetry:%d, retries: %d, reason: %s",
+                                maxRetryTimeoutMs, maxRetryTimes, k, e.getMessage());
+                        logger.warn(msg);
+                        throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exit retry for timeout
+                    }
+                } catch (Exception ignored) {
+                }
+
+                // for server-side returned error, only retry for rate limit
+                // in new error codes of v2.3, rate limit error value is 8
+                if (retryConfig.isRetryOnRateLimit() &&
+                        (e.getLegacyServerCode() == io.milvus.grpc.ErrorCode.RateLimit.getNumber() ||
+                                e.getServerErrCode() == 8)) {
+                    // cannot be retried
+                } else {
+                    throw e; // exit retry, throw the error
+                }
+            } catch (Exception e) {
+                throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e); // others error treated as client error
+            }
+
+            try {
+                if (k >= maxRetryTimes) {
+                    // finish retry loop, return the response of the last retry
+                    String msg = String.format("Finish %d retry times, stop retry", maxRetryTimes);
+                    logger.warn(msg);
+                    throw new MilvusClientException(ErrorCode.TIMEOUT, msg); // exceed max time, exit retry
+                } else {
+                    // sleep for interval
+                    // print log, follow the pymilvus logic
+                    if (k > 3) {
+                        logger.warn(String.format("Retry(%d) with interval %dms", k, retryIntervalMs));
+                    }
+                    TimeUnit.MILLISECONDS.sleep(retryIntervalMs);
+                }
+
+                // reset the next interval value
+                retryIntervalMs = retryIntervalMs*retryConfig.getBackOffMultiplier();
+                if (retryIntervalMs > retryConfig.getMaxBackOffMs()) {
+                    retryIntervalMs = retryConfig.getMaxBackOffMs();
+                }
+            } catch (Exception ignored) {
+            }
+        }
+
+        return null;
+    }
 }