Ver código fonte

Support GRPC timeout

jianghua 4 anos atrás
pai
commit
2ba9b11e02

+ 9 - 30
src/main/java/io/milvus/client/MilvusClient.java

@@ -21,6 +21,7 @@ package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 /** The Milvus Client Interface */
 public interface MilvusClient {
@@ -33,40 +34,18 @@ public interface MilvusClient {
   }
 
   /**
-   * Connects to Milvus server
-   *
-   * @param connectParam the <code>ConnectParam</code> object
-   *     <pre>
-   * example usage:
-   * <code>
-   * ConnectParam connectParam = new ConnectParam.Builder()
-   *                                             .withHost("localhost")
-   *                                             .withPort(19530)
-   *                                             .withConnectTimeout(10, TimeUnit.SECONDS)
-   *                                             .withKeepAliveTime(Long.MAX_VALUE, TimeUnit.NANOSECONDS)
-   *                                             .withKeepAliveTimeout(20, TimeUnit.SECONDS)
-   *                                             .keepAliveWithoutCalls(false)
-   *                                             .withIdleTimeout(24, TimeUnit.HOURS)
-   *                                             .build();
-   * </code>
-   * </pre>
-   *
-   * @return <code>Response</code>
-   * @throws ConnectFailedException if client failed to connect
-   * @see ConnectParam
-   * @see Response
-   * @see ConnectFailedException
+   * Close this MilvusClient. Wait at most 1 minute for graceful shutdown.
    */
-  Response connect(ConnectParam connectParam) throws ConnectFailedException;
+  default void close() {
+    close(TimeUnit.MINUTES.toSeconds(1));
+  }
 
   /**
-   * Disconnects from Milvus server
-   *
-   * @return <code>Response</code>
-   * @throws InterruptedException
-   * @see Response
+   * Close this MilvusClient. Wait at most `maxWaitSeconds` for graceful shutdown.
    */
-  Response disconnect() throws InterruptedException;
+  void close(long maxWaitSeconds);
+
+  MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit);
 
   /**
    * Creates collection specified by <code>collectionMapping</code>

+ 178 - 143
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -24,10 +24,16 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.protobuf.ByteString;
-import io.grpc.ConnectivityState;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.MethodDescriptor;
 import io.grpc.StatusRuntimeException;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.grpc.*;
 import java.nio.Buffer;
 import java.nio.ByteBuffer;
@@ -40,106 +46,145 @@ import javax.annotation.Nonnull;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-/** Actual implementation of interface <code>MilvusClient</code> */
-public class MilvusGrpcClient implements MilvusClient {
+public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
+  private static String SUPPORTED_SERVER_VERSION = "0.10";
+  private final ManagedChannel channel;
+  private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
+  private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
+
+  public MilvusGrpcClient(ConnectParam connectParam) {
+    channel = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+        .usePlaintext()
+        .maxInboundMessageSize(Integer.MAX_VALUE)
+        .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+        .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .build();
+    blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
+    futureStub = MilvusServiceGrpc.newFutureStub(channel);
 
-  private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
-  private final String extraParamKey = "params";
-  private ManagedChannel channel = null;
-  private MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub = null;
-  private MilvusServiceGrpc.MilvusServiceFutureStub futureStub = null;
+    try {
+      Response response = getServerVersion();
+      if (response.ok()) {
+        String serverVersion = getServerVersion().getMessage();
+        if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
+          throw new UnsupportedServerVersion(connectParam.getHost(), SUPPORTED_SERVER_VERSION, serverVersion);
+        }
+      } else {
+        throw new InitializationException(connectParam.getHost(), response.getMessage());
+      }
+    } catch (Throwable t) {
+      channel.shutdownNow();
+      throw t;
+    }
+  }
 
-  ////////////////////// Constructor //////////////////////
-  public MilvusGrpcClient() {}
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+    return blockingStub;
+  }
 
-  /////////////////////// Client Calls///////////////////////
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+    return futureStub;
+  }
 
   @Override
-  public Response connect(ConnectParam connectParam) throws ConnectFailedException {
-    if (channel != null && !(channel.isShutdown() || channel.isTerminated())) {
-      logWarning("Channel is not shutdown or terminated");
-      throw new ConnectFailedException("Channel is not shutdown or terminated");
+  protected boolean maybeAvailable() {
+    switch (channel.getState(false)) {
+      case IDLE:
+      case CONNECTING:
+      case READY:
+        return true;
+      default:
+        return false;
     }
+  }
 
-    try {
+  @Override
+  public void close(long maxWaitSeconds) {
+    channel.shutdown();
+    long now = System.nanoTime();
+    long deadline = now + TimeUnit.SECONDS.toNanos(maxWaitSeconds);
+    while (now < deadline && !channel.isTerminated()) {
+      try {
+        channel.awaitTermination(deadline - now, TimeUnit.NANOSECONDS);
+      } catch (InterruptedException ex) {
+      }
+    }
+    if (!channel.isTerminated()) {
+      channel.shutdownNow();
+    }
+  }
 
-      channel =
-          ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-              .usePlaintext()
-              .maxInboundMessageSize(Integer.MAX_VALUE)
-              .keepAliveTime(
-                  connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveTimeout(
-                  connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
-              .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .build();
-
-      channel.getState(true);
-
-      long timeout = connectParam.getConnectTimeout(TimeUnit.MILLISECONDS);
-      logInfo("Trying to connect...Timeout in {} ms", timeout);
-
-      final long checkFrequency = 100; // ms
-      while (channel.getState(false) != ConnectivityState.READY) {
-        if (timeout <= 0) {
-          logError("Connect timeout!");
-          throw new ConnectFailedException("Connect timeout");
-        }
-        TimeUnit.MILLISECONDS.sleep(checkFrequency);
-        timeout -= checkFrequency;
+  public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+    final long timeoutMillis = timeoutUnit.toMillis(timeout);
+    final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
+    final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
+        this.blockingStub.withInterceptors(timeoutInterceptor);
+    final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
+        this.futureStub.withInterceptors(timeoutInterceptor);
+
+    return new AbstractMilvusGrpcClient() {
+
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+        return blockingStub;
       }
 
-      blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
-      futureStub = MilvusServiceGrpc.newFutureStub(channel);
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+        return futureStub;
+      }
 
-      // check server version
-      String serverVersion = getServerVersion().getMessage();
-      if (!serverVersion.contains("0.10.")) {
-        logError(
-            "Connect failed! Server version {} does not match SDK version 0.8.4", serverVersion);
-        throw new ConnectFailedException("Failed to connect to Milvus server.");
+      @Override
+      protected boolean maybeAvailable() {
+        return MilvusGrpcClient.this.maybeAvailable();
       }
 
-    } catch (Exception e) {
-      if (!(e instanceof ConnectFailedException)) {
-        logError("Connect failed! {}", e.toString());
+      @Override
+      public void close(long maxWaitSeconds) {
+        MilvusGrpcClient.this.close(maxWaitSeconds);
       }
-      throw new ConnectFailedException("Exception occurred: " + e.toString());
-    }
 
-    logInfo(
-        "Connection established successfully to host={}, port={}",
-        connectParam.getHost(),
-        String.valueOf(connectParam.getPort()));
-    return new Response(Response.Status.SUCCESS);
+      @Override
+      public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+        return MilvusGrpcClient.this.withTimeout(timeout, timeoutUnit);
+      }
+    };
   }
 
-  @Override
-  public Response disconnect() throws InterruptedException {
-    if (!channelIsReadyOrIdle()) {
-      logWarning("You are not connected to Milvus server");
-      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
-    } else {
-      try {
-        if (channel.shutdown().awaitTermination(60, TimeUnit.SECONDS)) {
-          logInfo("Channel terminated");
-        } else {
-          logError("Encountered error when terminating channel");
-          return new Response(Response.Status.RPC_ERROR);
-        }
-      } catch (InterruptedException e) {
-        logError("Exception thrown when terminating channel: {}", e.toString());
-        throw e;
-      }
+  private static class TimeoutInterceptor implements ClientInterceptor {
+    private long timeoutMillis;
+
+    TimeoutInterceptor(long timeoutMillis) {
+      this.timeoutMillis = timeoutMillis;
+    }
+
+    @Override
+    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+      return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
     }
-    return new Response(Response.Status.SUCCESS);
   }
+}
+
+abstract class AbstractMilvusGrpcClient implements MilvusClient {
+  private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
+
+  private final String extraParamKey = "params";
+
+  protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
+
+  protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
+
+  protected abstract boolean maybeAvailable();
 
   @Override
   public Response createCollection(@Nonnull CollectionMapping collectionMapping) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -155,7 +200,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createCollection(request);
+      response = blockingStub().createCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created collection successfully!\n{}", collectionMapping.toString());
@@ -179,7 +224,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasCollectionResponse hasCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasCollectionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -188,7 +233,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasCollection(request);
+      response = blockingStub().hasCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("hasCollection `{}` = {}", collectionName, response.getBoolReply());
@@ -212,7 +257,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -221,7 +266,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropCollection(request);
+      response = blockingStub().dropCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped collection `{}` successfully!", collectionName);
@@ -240,7 +285,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createIndex(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -257,7 +302,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createIndex(request);
+      response = blockingStub().createIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created index successfully!\n{}", index.toString());
@@ -276,7 +321,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> createIndexAsync(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -292,7 +337,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.createIndex(request);
+    response = futureStub().createIndex(request);
 
     Futures.addCallback(
         response,
@@ -320,7 +365,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -331,7 +376,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createPartition(request);
+      response = blockingStub().createPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -354,7 +399,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasPartitionResponse hasPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -364,7 +409,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasPartition(request);
+      response = blockingStub().hasPartition(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -395,7 +440,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListPartitionsResponse listPartitions(String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListPartitionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -405,7 +450,7 @@ public class MilvusGrpcClient implements MilvusClient {
     PartitionList response;
 
     try {
-      response = blockingStub.showPartitions(request);
+      response = blockingStub().showPartitions(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -432,7 +477,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -442,7 +487,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropPartition(request);
+      response = blockingStub().dropPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -465,7 +510,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public InsertResponse insert(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new InsertResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -484,7 +529,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorIds response;
 
     try {
-      response = blockingStub.insert(request);
+      response = blockingStub().insert(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -511,7 +556,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<InsertResponse> insertAsync(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(
           new InsertResponse(
@@ -531,7 +576,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<VectorIds> response;
 
-    response = futureStub.insert(request);
+    response = futureStub().insert(request);
 
     Futures.addCallback(
         response,
@@ -575,7 +620,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public SearchResponse search(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -603,7 +648,7 @@ public class MilvusGrpcClient implements MilvusClient {
     TopKQueryResult response;
 
     try {
-      response = blockingStub.search(request);
+      response = blockingStub().search(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         SearchResponse searchResponse = buildSearchResponse(response);
@@ -632,7 +677,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -659,7 +704,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<TopKQueryResult> response;
 
-    response = futureStub.search(request);
+    response = futureStub().search(request);
 
     Futures.addCallback(
         response,
@@ -704,7 +749,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetCollectionInfoResponse getCollectionInfo(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetCollectionInfoResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
@@ -714,7 +759,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionSchema response;
 
     try {
-      response = blockingStub.describeCollection(request);
+      response = blockingStub().describeCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         CollectionMapping collectionMapping =
@@ -746,7 +791,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListCollectionsResponse listCollections() {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListCollectionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -756,7 +801,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionNameList response;
 
     try {
-      response = blockingStub.showCollections(request);
+      response = blockingStub().showCollections(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         List<String> collectionNames = response.getCollectionNamesList();
@@ -780,7 +825,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public CountEntitiesResponse countEntities(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new CountEntitiesResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
     }
@@ -789,7 +834,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionRowCount response;
 
     try {
-      response = blockingStub.countCollection(request);
+      response = blockingStub().countCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         long collectionRowCount = response.getCollectionRowCount();
@@ -824,7 +869,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   public Response command(@Nonnull String command) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -833,7 +878,7 @@ public class MilvusGrpcClient implements MilvusClient {
     StringReply response;
 
     try {
-      response = blockingStub.cmd(request);
+      response = blockingStub().cmd(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Command `{}`: {}", command, response.getStringReply());
@@ -853,7 +898,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response loadCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -862,7 +907,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.preloadCollection(request);
+      response = blockingStub().preloadCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Loaded collection `{}` successfully!", collectionName);
@@ -881,7 +926,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetIndexInfoResponse getIndexInfo(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetIndexInfoResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
     }
@@ -890,7 +935,7 @@ public class MilvusGrpcClient implements MilvusClient {
     IndexParam response;
 
     try {
-      response = blockingStub.describeIndex(request);
+      response = blockingStub().describeIndex(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         String extraParam = "";
@@ -927,7 +972,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropIndex(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -936,7 +981,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropIndex(request);
+      response = blockingStub().dropIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped index for collection `{}` successfully!", collectionName);
@@ -954,7 +999,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response getCollectionStats(String collectionName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -963,7 +1008,7 @@ public class MilvusGrpcClient implements MilvusClient {
     io.milvus.grpc.CollectionInfo response;
 
     try {
-      response = blockingStub.showCollectionInfo(request);
+      response = blockingStub().showCollectionInfo(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("getCollectionStats for `{}` returned successfully!", collectionName);
@@ -985,7 +1030,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetEntityByIDResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList(), null);
@@ -996,7 +1041,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorsData response;
 
     try {
-      response = blockingStub.getVectorsByID(request);
+      response = blockingStub().getVectorsByID(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1032,7 +1077,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public ListIDInSegmentResponse listIDInSegment(String collectionName, String segmentName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListIDInSegmentResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -1046,7 +1091,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorIds response;
 
     try {
-      response = blockingStub.getVectorIDs(request);
+      response = blockingStub().getVectorIDs(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1077,7 +1122,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response deleteEntityByID(String collectionName, List<Long> ids) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1087,7 +1132,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.deleteByID(request);
+      response = blockingStub().deleteByID(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("deleteEntityByID in collection `{}` completed successfully!", collectionName);
@@ -1106,7 +1151,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response flush(List<String> collectionNames) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1115,7 +1160,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.flush(request);
+      response = blockingStub().flush(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Flushed collection {} successfully!", collectionNames);
@@ -1134,7 +1179,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> flushAsync(@Nonnull List<String> collectionNames) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1143,7 +1188,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.flush(request);
+    response = futureStub().flush(request);
 
     Futures.addCallback(
         response,
@@ -1192,7 +1237,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response compact(String collectionName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1201,7 +1246,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.compact(request);
+      response = blockingStub().compact(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Compacted collection `{}` successfully!", collectionName);
@@ -1220,7 +1265,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> compactAsync(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1229,7 +1274,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.compact(request);
+    response = futureStub().compact(request);
 
     Futures.addCallback(
         response,
@@ -1322,16 +1367,6 @@ public class MilvusGrpcClient implements MilvusClient {
     return searchResponse;
   }
 
-  private boolean channelIsReadyOrIdle() {
-    if (channel == null) {
-      return false;
-    }
-    ConnectivityState connectivityState = channel.getState(false);
-    return connectivityState == ConnectivityState.READY
-        || connectivityState
-            == ConnectivityState.IDLE; // Since a new RPC would take the channel out of idle mode
-  }
-
   ///////////////////// Log Functions//////////////////////
 
   private void logInfo(String msg, Object... params) {

+ 7 - 2
src/main/java/io/milvus/client/exception/InitializationException.java

@@ -1,11 +1,16 @@
 package io.milvus.client.exception;
 
-public class InitializationFailedException extends MilvusException {
+public class InitializationException extends MilvusException {
   private String host;
   private Throwable cause;
 
-  public InitializationFailedException(String host, Throwable cause) {
+  public InitializationException(String host, Throwable cause) {
     super(false, cause);
     this.host = host;
   }
+
+  public InitializationException(String host, String message) {
+    super(false, message);
+    this.host = host;
+  }
 }

+ 23 - 1
src/main/java/io/milvus/client/exception/MilvusException.java

@@ -1,2 +1,24 @@
-package io.milvus.client.exception;public class MilvusException {
+package io.milvus.client.exception;
+
+public class MilvusException extends RuntimeException {
+  private boolean fillInStackTrace;
+
+  MilvusException(boolean fillInStackTrace) {
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, Throwable cause) {
+    super(cause);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, String message) {
+    super(message);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  @Override
+  public synchronized Throwable fillInStackTrace() {
+    return fillInStackTrace ? super.fillInStackTrace() : this;
+  }
 }

+ 20 - 2
src/main/java/io/milvus/client/exception/UnsupportedServerVersion.java

@@ -1,4 +1,22 @@
-package io.milvus.client.exception.UnsupportedServerVersion;
+package io.milvus.client.exception;
 
-public class UnsupportedServerVersion {
+import io.milvus.client.MilvusClient;
+
+public class UnsupportedServerVersion extends MilvusException {
+  private String host;
+  private String expect;
+  private String actual;
+
+  public UnsupportedServerVersion(String host, String expect, String actual) {
+    super(false);
+    this.host = host;
+    this.expect = expect;
+    this.actual = actual;
+  }
+
+  @Override
+  public String getMessage() {
+    return String.format("%s: Milvus client %s is expected to work with Milvus server %s, but the version of the connected server is %s",
+        host, MilvusClient.clientVersion, expect, actual);
+  }
 }

+ 30 - 7
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -20,6 +20,8 @@
 package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import org.apache.commons.text.RandomStringGenerator;
 import org.json.*;
 import org.testcontainers.containers.GenericContainer;
@@ -102,9 +104,8 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.BeforeEach
   void setUp() throws Exception {
-    client = new MilvusGrpcClient();
     ConnectParam connectParam = connectParamBuilder().build();
-    client.connect(connectParam);
+    client = new MilvusGrpcClient(connectParam);
 
     generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
     randomCollectionName = generator.generate(10);
@@ -122,16 +123,15 @@ class MilvusClientTest {
   @org.junit.jupiter.api.AfterEach
   void tearDown() throws InterruptedException {
     assertTrue(client.dropCollection(randomCollectionName).ok());
-    client.disconnect();
+    client.close();
   }
 
   @org.junit.jupiter.api.Test
   void idleTest() throws InterruptedException, ConnectFailedException {
-    MilvusClient client = new MilvusGrpcClient();
     ConnectParam connectParam = connectParamBuilder()
         .withIdleTimeout(1, TimeUnit.SECONDS)
         .build();
-    client.connect(connectParam);
+    MilvusClient client = new MilvusGrpcClient(connectParam);
     TimeUnit.SECONDS.sleep(2);
     // A new RPC would take the channel out of idle mode
     assertTrue(client.listCollections().ok());
@@ -172,9 +172,32 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.Test
   void connectUnreachableHost() {
-    MilvusClient client = new MilvusGrpcClient();
     ConnectParam connectParam = connectParamBuilder("250.250.250.250", 19530).build();
-    assertThrows(ConnectFailedException.class, () -> client.connect(connectParam));
+    assertThrows(InitializationException.class, () -> new MilvusGrpcClient(connectParam));
+  }
+
+  @org.junit.jupiter.api.Test
+  void unsupportedServerVersion() {
+    GenericContainer unsupportedMilvusContainer =
+        new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
+            .withExposedPorts(19530);
+    try {
+      unsupportedMilvusContainer.start();
+      ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
+      assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
+    } finally {
+      unsupportedMilvusContainer.stop();
+    }
+  }
+
+  @org.junit.jupiter.api.Test
+  void grpcTimeout() {
+    insert();
+    MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
+    Response response = timeoutClient.createIndex(
+        new Index.Builder(randomCollectionName, IndexType.FLAT)
+            .withParamsInJson("{\"nlist\": 16384}").build());
+    assertEquals(Response.Status.RPC_ERROR, response.getStatus());
   }
 
   @org.junit.jupiter.api.Test