浏览代码

gRPC timeout support

jianghua 4 年之前
父节点
当前提交
2908a6828d

+ 9 - 30
src/main/java/io/milvus/client/MilvusClient.java

@@ -26,6 +26,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.util.List;
 import java.util.Properties;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 
 /** The Milvus Client Interface */
@@ -54,40 +55,18 @@ public interface MilvusClient {
   }
 
   /**
-   * Connects to Milvus server
-   *
-   * @param connectParam the <code>ConnectParam</code> object
-   * <pre>
-   * example usage:
-   * <code>
-   * ConnectParam connectParam = new ConnectParam.Builder()
-   *                                             .withHost("localhost")
-   *                                             .withPort(19530)
-   *                                             .withConnectTimeout(10, TimeUnit.SECONDS)
-   *                                             .withKeepAliveTime(Long.MAX_VALUE, TimeUnit.NANOSECONDS)
-   *                                             .withKeepAliveTimeout(20, TimeUnit.SECONDS)
-   *                                             .keepAliveWithoutCalls(false)
-   *                                             .withIdleTimeout(24, TimeUnit.HOURS)
-   *                                             .build();
-   * </code>
-   * </pre>
-   *
-   * @return <code>Response</code>
-   * @throws ConnectFailedException if client failed to connect
-   * @see ConnectParam
-   * @see Response
-   * @see ConnectFailedException
+   * Close this MilvusClient. Wait at most 1 minute for graceful shutdown.
    */
-  Response connect(ConnectParam connectParam) throws ConnectFailedException;
+  default void close() {
+    close(TimeUnit.MINUTES.toSeconds(1));
+  }
 
   /**
-   * Disconnects from Milvus server
-   *
-   * @return <code>Response</code>
-   * @throws InterruptedException if disconnect interrupted
-   * @see Response
+   * Close this MilvusClient. Wait at most `maxWaitSeconds` for graceful shutdown.
    */
-  Response disconnect() throws InterruptedException;
+  void close(long maxWaitSeconds);
+
+  MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit);
 
   /**
    * Creates collection specified by <code>collectionMapping</code>

+ 178 - 143
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -24,11 +24,25 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.protobuf.ByteString;
-import io.grpc.ConnectivityState;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.MethodDescriptor;
 import io.grpc.StatusRuntimeException;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.grpc.*;
+import org.apache.commons.lang3.ArrayUtils;
+import org.json.JSONArray;
+import org.json.JSONException;
+import org.json.JSONObject;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -39,114 +53,145 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
-import javax.annotation.Nonnull;
-import org.apache.commons.lang3.ArrayUtils;
-import org.json.JSONArray;
-import org.json.JSONException;
-import org.json.JSONObject;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /** Actual implementation of interface <code>MilvusClient</code> */
-public class MilvusGrpcClient implements MilvusClient {
+public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
 
   private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
-  private final String extraParamKey = "params";
-  private ManagedChannel channel = null;
-  private MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub = null;
-  private MilvusServiceGrpc.MilvusServiceFutureStub futureStub = null;
+  private static final String SUPPORTED_SERVER_VERSION = "0.11";
+
+  private final ManagedChannel channel;
+  private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
+  private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
+
+  public MilvusGrpcClient(ConnectParam connectParam) {
+    channel = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+        .usePlaintext()
+        .maxInboundMessageSize(Integer.MAX_VALUE)
+        .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+        .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .build();
+    blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
+    futureStub = MilvusServiceGrpc.newFutureStub(channel);
+    try {
+      Response response = getServerVersion();
+      if (response.ok()) {
+        String serverVersion = response.getMessage();
+        if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
+          throw new UnsupportedServerVersion(connectParam.getHost(), SUPPORTED_SERVER_VERSION, serverVersion);
+        }
+      } else {
+        throw new InitializationException(connectParam.getHost(), response.getMessage());
+      }
+    } catch (Throwable t) {
+      channel.shutdownNow();
+      throw t;
+    }
+  }
 
-  ////////////////////// Constructor //////////////////////
-  public MilvusGrpcClient() {}
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+    return blockingStub;
+  }
 
-  /////////////////////// Client Calls///////////////////////
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+    return futureStub;
+  }
 
   @Override
-  public Response connect(ConnectParam connectParam) throws ConnectFailedException {
-    if (channel != null && !(channel.isShutdown() || channel.isTerminated())) {
-      logWarning("Channel is not shutdown or terminated");
-      throw new ConnectFailedException("Channel is not shutdown or terminated");
+  protected boolean maybeAvailable() {
+    switch (channel.getState(false)) {
+      case IDLE:
+      case CONNECTING:
+      case READY:
+        return true;
+      default:
+        return false;
     }
+  }
 
+  @Override
+  public void close(long maxWaitSeconds) {
+    channel.shutdown();
     try {
+      channel.awaitTermination(maxWaitSeconds, TimeUnit.SECONDS);
+    } catch (InterruptedException ex) {
+      logger.warn("Milvus client close interrupted");
+      channel.shutdownNow();
+      Thread.currentThread().interrupt();
+    }
+  }
 
-      channel =
-          ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-              .usePlaintext()
-              .maxInboundMessageSize(Integer.MAX_VALUE)
-              .keepAliveTime(
-                  connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveTimeout(
-                  connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
-              .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .build();
-
-      channel.getState(true);
+  public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+    final long timeoutMillis = timeoutUnit.toMillis(timeout);
+    final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
+    final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
+        this.blockingStub.withInterceptors(timeoutInterceptor);
+    final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
+        this.futureStub.withInterceptors(timeoutInterceptor);
 
-      long timeout = connectParam.getConnectTimeout(TimeUnit.MILLISECONDS);
-      logInfo("Trying to connect...Timeout in {} ms", timeout);
+    return new AbstractMilvusGrpcClient() {
 
-      final long checkFrequency = 100; // ms
-      while (channel.getState(false) != ConnectivityState.READY) {
-        if (timeout <= 0) {
-          logError("Connect timeout!");
-          throw new ConnectFailedException("Connect timeout");
-        }
-        TimeUnit.MILLISECONDS.sleep(checkFrequency);
-        timeout -= checkFrequency;
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+        return blockingStub;
       }
 
-      blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
-      futureStub = MilvusServiceGrpc.newFutureStub(channel);
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+        return futureStub;
+      }
 
-      // check server version
-      String serverVersion = getServerVersion().getMessage();
-      if (!serverVersion.contains("0.11.")) {
-        logError(
-            "Connect failed! Server version {} does not match SDK version 0.9.0", serverVersion);
-        throw new ConnectFailedException("Failed to connect to Milvus server.");
+      @Override
+      protected boolean maybeAvailable() {
+        return MilvusGrpcClient.this.maybeAvailable();
       }
 
-    } catch (Exception e) {
-      if (!(e instanceof ConnectFailedException)) {
-        logError("Connect failed! {}", e.toString());
+      @Override
+      public void close(long maxWaitSeconds) {
+        MilvusGrpcClient.this.close(maxWaitSeconds);
       }
-      throw new ConnectFailedException("Exception occurred: " + e.toString());
-    }
 
-    logInfo(
-        "Connection established successfully to host={}, port={}",
-        connectParam.getHost(),
-        String.valueOf(connectParam.getPort()));
-    return new Response(Response.Status.SUCCESS);
+      @Override
+      public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+        return MilvusGrpcClient.this.withTimeout(timeout, timeoutUnit);
+      }
+    };
   }
 
-  @Override
-  public Response disconnect() throws InterruptedException {
-    if (!channelIsReadyOrIdle()) {
-      logWarning("You are not connected to Milvus server");
-      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
-    } else {
-      try {
-        if (channel.shutdown().awaitTermination(60, TimeUnit.SECONDS)) {
-          logInfo("Channel terminated");
-        } else {
-          logError("Encountered error when terminating channel");
-          return new Response(Response.Status.RPC_ERROR);
-        }
-      } catch (InterruptedException e) {
-        logError("Exception thrown when terminating channel: {}", e.toString());
-        throw e;
-      }
+  private static class TimeoutInterceptor implements ClientInterceptor {
+    private long timeoutMillis;
+
+    TimeoutInterceptor(long timeoutMillis) {
+      this.timeoutMillis = timeoutMillis;
+    }
+
+    @Override
+    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+      return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
     }
-    return new Response(Response.Status.SUCCESS);
   }
+}
+
+abstract class AbstractMilvusGrpcClient implements MilvusClient {
+  private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
+
+  private final String extraParamKey = "params";
+
+  protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
+
+  protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
+
+  protected abstract boolean maybeAvailable();
 
   @Override
   public Response createCollection(@Nonnull CollectionMapping collectionMapping) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -191,7 +236,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createCollection(request);
+      response = blockingStub().createCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created collection successfully!\n{}", collectionMapping.toString());
@@ -215,7 +260,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasCollectionResponse hasCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasCollectionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -224,7 +269,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasCollection(request);
+      response = blockingStub().hasCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("hasCollection `{}` = {}", collectionName, response.getBoolReply());
@@ -248,7 +293,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -257,7 +302,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropCollection(request);
+      response = blockingStub().dropCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped collection `{}` successfully!", collectionName);
@@ -276,7 +321,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createIndex(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -309,7 +354,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createIndex(request);
+      response = blockingStub().createIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created index successfully!\n{}", index.toString());
@@ -328,7 +373,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> createIndexAsync(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -360,7 +405,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.createIndex(request);
+    response = futureStub().createIndex(request);
 
     Futures.addCallback(
         response,
@@ -388,7 +433,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -399,7 +444,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createPartition(request);
+      response = blockingStub().createPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -422,7 +467,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasPartitionResponse hasPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -432,7 +477,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasPartition(request);
+      response = blockingStub().hasPartition(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -463,7 +508,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListPartitionsResponse listPartitions(String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListPartitionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -473,7 +518,7 @@ public class MilvusGrpcClient implements MilvusClient {
     PartitionList response;
 
     try {
-      response = blockingStub.showPartitions(request);
+      response = blockingStub().showPartitions(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -500,7 +545,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -510,7 +555,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropPartition(request);
+      response = blockingStub().dropPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -534,7 +579,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @SuppressWarnings("unchecked")
   public InsertResponse insert(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new InsertResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -621,7 +666,7 @@ public class MilvusGrpcClient implements MilvusClient {
     EntityIds response;
 
     try {
-      response = blockingStub.insert(request);
+      response = blockingStub().insert(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -649,7 +694,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @SuppressWarnings("unchecked")
   public ListenableFuture<InsertResponse> insertAsync(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(
           new InsertResponse(
@@ -739,7 +784,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<EntityIds> response;
 
-    response = futureStub.insert(request);
+    response = futureStub().insert(request);
 
     Futures.addCallback(
         response,
@@ -783,7 +828,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public SearchResponse search(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -889,7 +934,7 @@ public class MilvusGrpcClient implements MilvusClient {
     QueryResult response;
 
     try {
-      response = blockingStub.search(request);
+      response = blockingStub().search(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         SearchResponse searchResponse = buildSearchResponse(response);
@@ -918,7 +963,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -1023,7 +1068,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<QueryResult> response;
 
-    response = futureStub.search(request);
+    response = futureStub().search(request);
 
     Futures.addCallback(
         response,
@@ -1068,7 +1113,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetCollectionInfoResponse getCollectionInfo(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetCollectionInfoResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
@@ -1078,7 +1123,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Mapping response;
 
     try {
-      response = blockingStub.describeCollection(request);
+      response = blockingStub().describeCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         String extraParam = "";
@@ -1128,7 +1173,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListCollectionsResponse listCollections() {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListCollectionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -1138,7 +1183,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionNameList response;
 
     try {
-      response = blockingStub.showCollections(request);
+      response = blockingStub().showCollections(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         List<String> collectionNames = response.getCollectionNamesList();
@@ -1162,7 +1207,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public CountEntitiesResponse countEntities(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new CountEntitiesResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
     }
@@ -1171,7 +1216,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionRowCount response;
 
     try {
-      response = blockingStub.countCollection(request);
+      response = blockingStub().countCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         long collectionRowCount = response.getCollectionRowCount();
@@ -1206,7 +1251,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   public Response command(@Nonnull String command) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1215,7 +1260,7 @@ public class MilvusGrpcClient implements MilvusClient {
     StringReply response;
 
     try {
-      response = blockingStub.cmd(request);
+      response = blockingStub().cmd(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Command `{}`: {}", command, response.getStringReply());
@@ -1235,7 +1280,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response loadCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1244,7 +1289,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.preloadCollection(request);
+      response = blockingStub().preloadCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Loaded collection `{}` successfully!", collectionName);
@@ -1263,7 +1308,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropIndex(String collectionName, String fieldName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1276,7 +1321,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropIndex(request);
+      response = blockingStub().dropIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped index for collection `{}` successfully!", collectionName);
@@ -1294,7 +1339,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response getCollectionStats(String collectionName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1303,7 +1348,7 @@ public class MilvusGrpcClient implements MilvusClient {
     io.milvus.grpc.CollectionInfo response;
 
     try {
-      response = blockingStub.showCollectionInfo(request);
+      response = blockingStub().showCollectionInfo(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("getCollectionStats for `{}` returned successfully!", collectionName);
@@ -1325,7 +1370,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids, List<String> fieldNames) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetEntityByIDResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -1340,7 +1385,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Entities response;
 
     try {
-      response = blockingStub.getEntityByID(request);
+      response = blockingStub().getEntityByID(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1411,7 +1456,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public ListIDInSegmentResponse listIDInSegment(String collectionName, Long segmentId) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListIDInSegmentResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -1425,7 +1470,7 @@ public class MilvusGrpcClient implements MilvusClient {
     EntityIds response;
 
     try {
-      response = blockingStub.getEntityIDs(request);
+      response = blockingStub().getEntityIDs(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1456,7 +1501,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response deleteEntityByID(String collectionName, List<Long> ids) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1466,7 +1511,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.deleteByID(request);
+      response = blockingStub().deleteByID(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("deleteEntityByID in collection `{}` completed successfully!", collectionName);
@@ -1485,7 +1530,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response flush(List<String> collectionNames) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1494,7 +1539,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.flush(request);
+      response = blockingStub().flush(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Flushed collection {} successfully!", collectionNames);
@@ -1513,7 +1558,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> flushAsync(@Nonnull List<String> collectionNames) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1522,7 +1567,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.flush(request);
+    response = futureStub().flush(request);
 
     Futures.addCallback(
         response,
@@ -1571,7 +1616,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response compact(CompactParam compactParam) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1584,7 +1629,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.compact(request);
+      response = blockingStub().compact(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Compacted collection `{}` successfully!", compactParam.getCollectionName());
@@ -1604,7 +1649,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> compactAsync(@Nonnull CompactParam compactParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1617,7 +1662,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.compact(request);
+    response = futureStub().compact(request);
 
     Futures.addCallback(
         response,
@@ -1730,16 +1775,6 @@ public class MilvusGrpcClient implements MilvusClient {
     return searchResponse;
   }
 
-  private boolean channelIsReadyOrIdle() {
-    if (channel == null) {
-      return false;
-    }
-    ConnectivityState connectivityState = channel.getState(false);
-    return connectivityState == ConnectivityState.READY
-        || connectivityState
-            == ConnectivityState.IDLE; // Since a new RPC would take the channel out of idle mode
-  }
-
   private String kvListToString(List<KeyValuePair> kv) {
     JSONObject jsonObject = new JSONObject();
     for (KeyValuePair keyValuePair : kv) {

+ 7 - 0
src/main/java/io/milvus/client/exception/InitializationException.java

@@ -0,0 +1,7 @@
+package io.milvus.client.exception;
+
+public class InitializationException extends MilvusException {
+  public InitializationException(String host, String message) {
+    super(false, host + ": " + message);
+  }
+}

+ 24 - 0
src/main/java/io/milvus/client/exception/MilvusException.java

@@ -0,0 +1,24 @@
+package io.milvus.client.exception;
+
+public class MilvusException extends RuntimeException {
+  private boolean fillInStackTrace;
+
+  MilvusException(boolean fillInStackTrace) {
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, Throwable cause) {
+    super(cause);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, String message) {
+    super(message);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  @Override
+  public synchronized Throwable fillInStackTrace() {
+    return fillInStackTrace ? super.fillInStackTrace() : this;
+  }
+}

+ 22 - 0
src/main/java/io/milvus/client/exception/UnsupportedServerVersion.java

@@ -0,0 +1,22 @@
+package io.milvus.client.exception;
+
+import io.milvus.client.MilvusClient;
+
+public class UnsupportedServerVersion extends MilvusException {
+  private String host;
+  private String expect;
+  private String actual;
+
+  public UnsupportedServerVersion(String host, String expect, String actual) {
+    super(false);
+    this.host = host;
+    this.expect = expect;
+    this.actual = actual;
+  }
+
+  @Override
+  public String getMessage() {
+    return String.format("%s: Milvus client %s is expected to work with Milvus server %s, but the version of the connected server is %s",
+        host, MilvusClient.clientVersion, expect, actual);
+  }
+}

+ 42 - 14
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -25,6 +25,8 @@ import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import io.milvus.client.InsertParam.Builder;
 import io.milvus.client.Response.Status;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.text.RandomStringGenerator;
 import org.checkerframework.checker.nullness.compatqual.NullableDecl;
@@ -68,10 +70,6 @@ class ContainerMilvusClientTest extends MilvusClientTest {
   protected ConnectParam.Builder connectParamBuilder() {
     return connectParamBuilder(milvusContainer);
   }
-
-  private ConnectParam.Builder connectParamBuilder(GenericContainer milvusContainer) {
-    return connectParamBuilder(milvusContainer.getHost(), milvusContainer.getFirstMappedPort());
-  }
 }
 
 @Testcontainers
@@ -90,6 +88,10 @@ class MilvusClientTest {
     return connectParamBuilder("localhost", 19530);
   }
 
+  protected ConnectParam.Builder connectParamBuilder(GenericContainer milvusContainer) {
+    return connectParamBuilder(milvusContainer.getHost(), milvusContainer.getFirstMappedPort());
+  }
+
   protected ConnectParam.Builder connectParamBuilder(String host, int port) {
     return new ConnectParam.Builder().withHost(host).withPort(port);
   }
@@ -171,9 +173,8 @@ class MilvusClientTest {
   @org.junit.jupiter.api.BeforeEach
   void setUp() throws Exception {
 
-    client = new MilvusGrpcClient();
     ConnectParam connectParam = connectParamBuilder().build();
-    client.connect(connectParam);
+    client = new MilvusGrpcClient(connectParam);
 
     generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
     randomCollectionName = generator.generate(10);
@@ -195,18 +196,17 @@ class MilvusClientTest {
   }
 
   @org.junit.jupiter.api.AfterEach
-  void tearDown() throws InterruptedException {
-    assertTrue(client.dropCollection(randomCollectionName).ok());
-    client.disconnect();
+  void tearDown() {
+    client.dropCollection(randomCollectionName);
+    client.close();
   }
 
   @org.junit.jupiter.api.Test
-  void idleTest() throws InterruptedException, ConnectFailedException {
-    MilvusClient client = new MilvusGrpcClient();
+  void idleTest() throws InterruptedException {
     ConnectParam connectParam = connectParamBuilder()
         .withIdleTimeout(1, TimeUnit.SECONDS)
         .build();
-    client.connect(connectParam);
+    MilvusClient client = new MilvusGrpcClient(connectParam);
     TimeUnit.SECONDS.sleep(2);
     // A new RPC would take the channel out of idle mode
     assertTrue(client.listCollections().ok());
@@ -247,9 +247,37 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.Test
   void connectUnreachableHost() {
-    MilvusClient client = new MilvusGrpcClient();
     ConnectParam connectParam = connectParamBuilder("250.250.250.250", 19530).build();
-    assertThrows(ConnectFailedException.class, () -> client.connect(connectParam));
+    assertThrows(InitializationException.class, () -> new MilvusGrpcClient(connectParam));
+  }
+
+  @org.junit.jupiter.api.Test
+  void unsupportedServerVersion() {
+    GenericContainer unsupportedMilvusContainer =
+        new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
+            .withExposedPorts(19530);
+    try {
+      unsupportedMilvusContainer.start();
+      ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
+      assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
+    } finally {
+      unsupportedMilvusContainer.stop();
+    }
+  }
+
+  @org.junit.jupiter.api.Test
+  void grpcTimeout() {
+    insert();
+    MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
+    Response response = timeoutClient.createIndex(
+        new Index.Builder(randomCollectionName, "float_vec")
+            .withParamsInJson(new JsonBuilder()
+                .param("index_type", "IVF_FLAT")
+                .param("metric_type", "L2")
+                .indexParam("nlist", 2048)
+                .build())
+            .build());
+    assertEquals(Response.Status.RPC_ERROR, response.getStatus());
   }
 
   @org.junit.jupiter.api.Test