Pārlūkot izejas kodu

Merge pull request #128 from thirstycrow/grpc-timeout

GRPC timeout support
Xiaohai Xu 4 gadi atpakaļ
vecāks
revīzija
a4e8aab4a8

+ 2 - 2
examples/pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java-examples</artifactId>
-    <version>0.8.3</version>
+    <version>0.8.5-SNAPSHOT</version>
     <build>
         <plugins>
             <plugin>
@@ -63,7 +63,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>0.8.4</version>
+            <version>0.8.5-SNAPSHOT</version>
         </dependency>
         <dependency>
             <groupId>com.google.code.gson</groupId>

+ 3 - 17
examples/src/main/java/MilvusClientExample.java

@@ -56,7 +56,7 @@ public class MilvusClientExample {
     return vector;
   }
 
-  public static void main(String[] args) throws InterruptedException, ConnectFailedException {
+  public static void main(String[] args) throws InterruptedException {
 
     // You may need to change the following to the host and port of your Milvus server
     String host = "localhost";
@@ -66,17 +66,8 @@ public class MilvusClientExample {
       port = Integer.parseInt(args[1]);
     }
 
-    // Create Milvus client
-    MilvusClient client = new MilvusGrpcClient();
-
-    // Connect to Milvus server
     ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build();
-    try {
-      Response connectResponse = client.connect(connectParam);
-    } catch (ConnectFailedException e) {
-      System.out.println("Failed to connect to Milvus server: " + e.toString());
-      throw e;
-    }
+    MilvusClient client = new MilvusGrpcClient(connectParam);
 
     // Create a collection with the following collection mapping
     final String collectionName = "example"; // collection name
@@ -217,11 +208,6 @@ public class MilvusClientExample {
     Response dropCollectionResponse = client.dropCollection(collectionName);
 
     // Disconnect from Milvus server
-    try {
-      Response disconnectResponse = client.disconnect();
-    } catch (InterruptedException e) {
-      System.out.println("Failed to disconnect: " + e.toString());
-      throw e;
-    }
+    client.close();
   }
 }

+ 40 - 3
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>0.8.4</version>
+    <version>0.8.5-SNAPSHOT</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>
@@ -135,7 +135,7 @@
         <dependency>
             <groupId>org.junit.jupiter</groupId>
             <artifactId>junit-jupiter</artifactId>
-            <version>5.5.2</version>
+            <version>5.6.2</version>
             <scope>test</scope>
         </dependency>
         <dependency>
@@ -163,7 +163,18 @@
             <artifactId>log4j-slf4j-impl</artifactId>
             <version>2.12.1</version>
         </dependency>
-
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>testcontainers</artifactId>
+            <version>1.14.3</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>junit-jupiter</artifactId>
+            <version>1.14.3</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <profiles>
@@ -220,6 +231,12 @@
     </profiles>
 
     <build>
+        <resources>
+            <resource>
+                <directory>src/main/resources</directory>
+                <filtering>true</filtering>
+            </resource>
+        </resources>
         <extensions>
             <extension>
                 <groupId>kr.motd.maven</groupId>
@@ -288,6 +305,26 @@
                     </execution>
                 </executions>
             </plugin>
+            <!-- JUnit5 tests are not running with maven 3.6.x
+            https://dzone.com/articles/why-your-junit-5-tests-are-not-running-under-maven
+            -->
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <version>2.19.1</version>
+                <dependencies>
+                    <dependency>
+                        <groupId>org.junit.platform</groupId>
+                        <artifactId>junit-platform-surefire-provider</artifactId>
+                        <version>1.1.0</version>
+                    </dependency>
+                    <dependency>
+                        <groupId>org.junit.jupiter</groupId>
+                        <artifactId>junit-jupiter-engine</artifactId>
+                        <version>5.1.0</version>
+                    </dependency>
+                </dependencies>
+            </plugin>
         </plugins>
     </build>
 

+ 36 - 33
src/main/java/io/milvus/client/MilvusClient.java

@@ -20,53 +20,56 @@
 package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+
+import java.io.IOException;
+import java.io.InputStream;
 import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 
 /** The Milvus Client Interface */
 public interface MilvusClient {
 
-  String clientVersion = "0.8.4";
-
-  /** @return current Milvus client version: 0.8.4 */
+  String clientVersion = new Supplier<String>() {
+    @Override
+    public String get() {
+      Properties properties = new Properties();
+      InputStream inputStream = MilvusClient.class
+          .getClassLoader().getResourceAsStream("milvus-client.properties");
+      try {
+        properties.load(inputStream);
+      } catch (IOException ex) {
+        ExceptionUtils.wrapAndThrow(ex);
+      } finally {
+        try {
+          inputStream.close();
+        } catch (IOException ex) {
+        }
+      }
+      return properties.getProperty("version");
+    }
+  }.get();
+
+  /** @return current Milvus client version */
   default String getClientVersion() {
     return clientVersion;
   }
 
   /**
-   * Connects to Milvus server
-   *
-   * @param connectParam the <code>ConnectParam</code> object
-   *     <pre>
-   * example usage:
-   * <code>
-   * ConnectParam connectParam = new ConnectParam.Builder()
-   *                                             .withHost("localhost")
-   *                                             .withPort(19530)
-   *                                             .withConnectTimeout(10, TimeUnit.SECONDS)
-   *                                             .withKeepAliveTime(Long.MAX_VALUE, TimeUnit.NANOSECONDS)
-   *                                             .withKeepAliveTimeout(20, TimeUnit.SECONDS)
-   *                                             .keepAliveWithoutCalls(false)
-   *                                             .withIdleTimeout(24, TimeUnit.HOURS)
-   *                                             .build();
-   * </code>
-   * </pre>
-   *
-   * @return <code>Response</code>
-   * @throws ConnectFailedException if client failed to connect
-   * @see ConnectParam
-   * @see Response
-   * @see ConnectFailedException
+   * Close this MilvusClient. Wait at most 1 minute for graceful shutdown.
    */
-  Response connect(ConnectParam connectParam) throws ConnectFailedException;
+  default void close() {
+    close(TimeUnit.MINUTES.toSeconds(1));
+  }
 
   /**
-   * Disconnects from Milvus server
-   *
-   * @return <code>Response</code>
-   * @throws InterruptedException
-   * @see Response
+   * Close this MilvusClient. Wait at most `maxWaitSeconds` for graceful shutdown.
    */
-  Response disconnect() throws InterruptedException;
+  void close(long maxWaitSeconds);
+
+  MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit);
 
   /**
    * Creates collection specified by <code>collectionMapping</code>

+ 185 - 142
src/main/java/io/milvus/client/MilvusGrpcClient.java

@@ -24,10 +24,16 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.protobuf.ByteString;
-import io.grpc.ConnectivityState;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.MethodDescriptor;
 import io.grpc.StatusRuntimeException;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import io.milvus.grpc.*;
 import java.nio.Buffer;
 import java.nio.ByteBuffer;
@@ -40,106 +46,153 @@ import javax.annotation.Nonnull;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-/** Actual implementation of interface <code>MilvusClient</code> */
-public class MilvusGrpcClient implements MilvusClient {
+public class MilvusGrpcClient extends AbstractMilvusGrpcClient {
+  private static String SUPPORTED_SERVER_VERSION = "0.10";
+  private final ManagedChannel channel;
+  private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
+  private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
+
+  public MilvusGrpcClient(ConnectParam connectParam) {
+    channel = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+        .usePlaintext()
+        .maxInboundMessageSize(Integer.MAX_VALUE)
+        .keepAliveTime(connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveTimeout(connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+        .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
+        .build();
+    blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
+    futureStub = MilvusServiceGrpc.newFutureStub(channel);
 
-  private static final Logger logger = LoggerFactory.getLogger(MilvusGrpcClient.class);
-  private final String extraParamKey = "params";
-  private ManagedChannel channel = null;
-  private MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub = null;
-  private MilvusServiceGrpc.MilvusServiceFutureStub futureStub = null;
+    try {
+      Response response = getServerVersion();
+      if (response.ok()) {
+        String serverVersion = getServerVersion().getMessage();
+        if (!serverVersion.matches("^" + SUPPORTED_SERVER_VERSION + "(\\..*)?$")) {
+          throw new UnsupportedServerVersion(connectParam.getHost(), SUPPORTED_SERVER_VERSION, serverVersion);
+        }
+      } else {
+        throw new InitializationException(connectParam.getHost(), response.getMessage());
+      }
+    } catch (Throwable t) {
+      channel.shutdownNow();
+      throw t;
+    }
+  }
 
-  ////////////////////// Constructor //////////////////////
-  public MilvusGrpcClient() {}
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+    return blockingStub;
+  }
 
-  /////////////////////// Client Calls///////////////////////
+  @Override
+  protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+    return futureStub;
+  }
 
   @Override
-  public Response connect(ConnectParam connectParam) throws ConnectFailedException {
-    if (channel != null && !(channel.isShutdown() || channel.isTerminated())) {
-      logWarning("Channel is not shutdown or terminated");
-      throw new ConnectFailedException("Channel is not shutdown or terminated");
+  protected boolean maybeAvailable() {
+    switch (channel.getState(false)) {
+      case IDLE:
+      case CONNECTING:
+      case READY:
+        return true;
+      default:
+        return false;
     }
+  }
 
+  @Override
+  public void close(long maxWaitSeconds) {
+    channel.shutdown();
+    long now = System.nanoTime();
+    long deadline = now + TimeUnit.SECONDS.toNanos(maxWaitSeconds);
+    boolean interrupted = false;
     try {
-
-      channel =
-          ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-              .usePlaintext()
-              .maxInboundMessageSize(Integer.MAX_VALUE)
-              .keepAliveTime(
-                  connectParam.getKeepAliveTime(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveTimeout(
-                  connectParam.getKeepAliveTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
-              .idleTimeout(connectParam.getIdleTimeout(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS)
-              .build();
-
-      channel.getState(true);
-
-      long timeout = connectParam.getConnectTimeout(TimeUnit.MILLISECONDS);
-      logInfo("Trying to connect...Timeout in {} ms", timeout);
-
-      final long checkFrequency = 100; // ms
-      while (channel.getState(false) != ConnectivityState.READY) {
-        if (timeout <= 0) {
-          logError("Connect timeout!");
-          throw new ConnectFailedException("Connect timeout");
+      while (now < deadline && !channel.isTerminated()) {
+        try {
+          channel.awaitTermination(deadline - now, TimeUnit.NANOSECONDS);
+        } catch (InterruptedException ex) {
+          interrupted = true;
         }
-        TimeUnit.MILLISECONDS.sleep(checkFrequency);
-        timeout -= checkFrequency;
       }
+      if (!channel.isTerminated()) {
+        channel.shutdownNow();
+      }
+    } finally {
+      if (interrupted) {
+        Thread.currentThread().interrupt();
+      }
+    }
+  }
 
-      blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
-      futureStub = MilvusServiceGrpc.newFutureStub(channel);
+  public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+    final long timeoutMillis = timeoutUnit.toMillis(timeout);
+    final TimeoutInterceptor timeoutInterceptor = new TimeoutInterceptor(timeoutMillis);
+    final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub =
+        this.blockingStub.withInterceptors(timeoutInterceptor);
+    final MilvusServiceGrpc.MilvusServiceFutureStub futureStub =
+        this.futureStub.withInterceptors(timeoutInterceptor);
 
-      // check server version
-      String serverVersion = getServerVersion().getMessage();
-      if (!serverVersion.contains("0.10.")) {
-        logError(
-            "Connect failed! Server version {} does not match SDK version 0.8.4", serverVersion);
-        throw new ConnectFailedException("Failed to connect to Milvus server.");
+    return new AbstractMilvusGrpcClient() {
+
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub() {
+        return blockingStub;
       }
 
-    } catch (Exception e) {
-      if (!(e instanceof ConnectFailedException)) {
-        logError("Connect failed! {}", e.toString());
+      @Override
+      protected MilvusServiceGrpc.MilvusServiceFutureStub futureStub() {
+        return futureStub;
       }
-      throw new ConnectFailedException("Exception occurred: " + e.toString());
-    }
 
-    logInfo(
-        "Connection established successfully to host={}, port={}",
-        connectParam.getHost(),
-        String.valueOf(connectParam.getPort()));
-    return new Response(Response.Status.SUCCESS);
-  }
+      @Override
+      protected boolean maybeAvailable() {
+        return MilvusGrpcClient.this.maybeAvailable();
+      }
 
-  @Override
-  public Response disconnect() throws InterruptedException {
-    if (!channelIsReadyOrIdle()) {
-      logWarning("You are not connected to Milvus server");
-      return new Response(Response.Status.CLIENT_NOT_CONNECTED);
-    } else {
-      try {
-        if (channel.shutdown().awaitTermination(60, TimeUnit.SECONDS)) {
-          logInfo("Channel terminated");
-        } else {
-          logError("Encountered error when terminating channel");
-          return new Response(Response.Status.RPC_ERROR);
-        }
-      } catch (InterruptedException e) {
-        logError("Exception thrown when terminating channel: {}", e.toString());
-        throw e;
+      @Override
+      public void close(long maxWaitSeconds) {
+        MilvusGrpcClient.this.close(maxWaitSeconds);
       }
+
+      @Override
+      public MilvusClient withTimeout(long timeout, TimeUnit timeoutUnit) {
+        return MilvusGrpcClient.this.withTimeout(timeout, timeoutUnit);
+      }
+    };
+  }
+
+  private static class TimeoutInterceptor implements ClientInterceptor {
+    private long timeoutMillis;
+
+    TimeoutInterceptor(long timeoutMillis) {
+      this.timeoutMillis = timeoutMillis;
+    }
+
+    @Override
+    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+      return next.newCall(method, callOptions.withDeadlineAfter(timeoutMillis, TimeUnit.MILLISECONDS));
     }
-    return new Response(Response.Status.SUCCESS);
   }
+}
+
+abstract class AbstractMilvusGrpcClient implements MilvusClient {
+  private static final Logger logger = LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
+
+  private final String extraParamKey = "params";
+
+  protected abstract MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub();
+
+  protected abstract MilvusServiceGrpc.MilvusServiceFutureStub futureStub();
+
+  protected abstract boolean maybeAvailable();
 
   @Override
   public Response createCollection(@Nonnull CollectionMapping collectionMapping) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -155,7 +208,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createCollection(request);
+      response = blockingStub().createCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created collection successfully!\n{}", collectionMapping.toString());
@@ -179,7 +232,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasCollectionResponse hasCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasCollectionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -188,7 +241,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasCollection(request);
+      response = blockingStub().hasCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("hasCollection `{}` = {}", collectionName, response.getBoolReply());
@@ -212,7 +265,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -221,7 +274,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropCollection(request);
+      response = blockingStub().dropCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped collection `{}` successfully!", collectionName);
@@ -240,7 +293,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createIndex(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -257,7 +310,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createIndex(request);
+      response = blockingStub().createIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created index successfully!\n{}", index.toString());
@@ -276,7 +329,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> createIndexAsync(@Nonnull Index index) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -292,7 +345,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.createIndex(request);
+    response = futureStub().createIndex(request);
 
     Futures.addCallback(
         response,
@@ -320,7 +373,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response createPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -331,7 +384,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.createPartition(request);
+      response = blockingStub().createPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Created partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -354,7 +407,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public HasPartitionResponse hasPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new HasPartitionResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), false);
     }
@@ -364,7 +417,7 @@ public class MilvusGrpcClient implements MilvusClient {
     BoolReply response;
 
     try {
-      response = blockingStub.hasPartition(request);
+      response = blockingStub().hasPartition(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -395,7 +448,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListPartitionsResponse listPartitions(String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListPartitionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -405,7 +458,7 @@ public class MilvusGrpcClient implements MilvusClient {
     PartitionList response;
 
     try {
-      response = blockingStub.showPartitions(request);
+      response = blockingStub().showPartitions(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -432,7 +485,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropPartition(String collectionName, String tag) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -442,7 +495,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropPartition(request);
+      response = blockingStub().dropPartition(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped partition `{}` in collection `{}` successfully!", tag, collectionName);
@@ -465,7 +518,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public InsertResponse insert(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new InsertResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -484,7 +537,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorIds response;
 
     try {
-      response = blockingStub.insert(request);
+      response = blockingStub().insert(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo(
@@ -511,7 +564,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<InsertResponse> insertAsync(@Nonnull InsertParam insertParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(
           new InsertResponse(
@@ -531,7 +584,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<VectorIds> response;
 
-    response = futureStub.insert(request);
+    response = futureStub().insert(request);
 
     Futures.addCallback(
         response,
@@ -575,7 +628,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public SearchResponse search(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -603,7 +656,7 @@ public class MilvusGrpcClient implements MilvusClient {
     TopKQueryResult response;
 
     try {
-      response = blockingStub.search(request);
+      response = blockingStub().search(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         SearchResponse searchResponse = buildSearchResponse(response);
@@ -632,7 +685,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<SearchResponse> searchAsync(@Nonnull SearchParam searchParam) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       SearchResponse searchResponse = new SearchResponse();
       searchResponse.setResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED));
@@ -659,7 +712,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<TopKQueryResult> response;
 
-    response = futureStub.search(request);
+    response = futureStub().search(request);
 
     Futures.addCallback(
         response,
@@ -704,7 +757,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetCollectionInfoResponse getCollectionInfo(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetCollectionInfoResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
@@ -714,7 +767,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionSchema response;
 
     try {
-      response = blockingStub.describeCollection(request);
+      response = blockingStub().describeCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         CollectionMapping collectionMapping =
@@ -746,7 +799,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListCollectionsResponse listCollections() {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListCollectionsResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -756,7 +809,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionNameList response;
 
     try {
-      response = blockingStub.showCollections(request);
+      response = blockingStub().showCollections(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         List<String> collectionNames = response.getCollectionNamesList();
@@ -780,7 +833,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public CountEntitiesResponse countEntities(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new CountEntitiesResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), 0);
     }
@@ -789,7 +842,7 @@ public class MilvusGrpcClient implements MilvusClient {
     CollectionRowCount response;
 
     try {
-      response = blockingStub.countCollection(request);
+      response = blockingStub().countCollection(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         long collectionRowCount = response.getCollectionRowCount();
@@ -824,7 +877,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   public Response command(@Nonnull String command) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -833,7 +886,7 @@ public class MilvusGrpcClient implements MilvusClient {
     StringReply response;
 
     try {
-      response = blockingStub.cmd(request);
+      response = blockingStub().cmd(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Command `{}`: {}", command, response.getStringReply());
@@ -853,7 +906,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response loadCollection(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -862,7 +915,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.preloadCollection(request);
+      response = blockingStub().preloadCollection(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Loaded collection `{}` successfully!", collectionName);
@@ -881,7 +934,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public GetIndexInfoResponse getIndexInfo(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetIndexInfoResponse(new Response(Response.Status.CLIENT_NOT_CONNECTED), null);
     }
@@ -890,7 +943,7 @@ public class MilvusGrpcClient implements MilvusClient {
     IndexParam response;
 
     try {
-      response = blockingStub.describeIndex(request);
+      response = blockingStub().describeIndex(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         String extraParam = "";
@@ -927,7 +980,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public Response dropIndex(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -936,7 +989,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.dropIndex(request);
+      response = blockingStub().dropIndex(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Dropped index for collection `{}` successfully!", collectionName);
@@ -954,7 +1007,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response getCollectionStats(String collectionName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -963,7 +1016,7 @@ public class MilvusGrpcClient implements MilvusClient {
     io.milvus.grpc.CollectionInfo response;
 
     try {
-      response = blockingStub.showCollectionInfo(request);
+      response = blockingStub().showCollectionInfo(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("getCollectionStats for `{}` returned successfully!", collectionName);
@@ -985,7 +1038,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public GetEntityByIDResponse getEntityByID(String collectionName, List<Long> ids) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new GetEntityByIDResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList(), null);
@@ -996,7 +1049,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorsData response;
 
     try {
-      response = blockingStub.getVectorsByID(request);
+      response = blockingStub().getVectorsByID(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1032,7 +1085,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public ListIDInSegmentResponse listIDInSegment(String collectionName, String segmentName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new ListIDInSegmentResponse(
           new Response(Response.Status.CLIENT_NOT_CONNECTED), Collections.emptyList());
@@ -1046,7 +1099,7 @@ public class MilvusGrpcClient implements MilvusClient {
     VectorIds response;
 
     try {
-      response = blockingStub.getVectorIDs(request);
+      response = blockingStub().getVectorIDs(request);
 
       if (response.getStatus().getErrorCode() == ErrorCode.SUCCESS) {
 
@@ -1077,7 +1130,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response deleteEntityByID(String collectionName, List<Long> ids) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1087,7 +1140,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.deleteByID(request);
+      response = blockingStub().deleteByID(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("deleteEntityByID in collection `{}` completed successfully!", collectionName);
@@ -1106,7 +1159,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response flush(List<String> collectionNames) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1115,7 +1168,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.flush(request);
+      response = blockingStub().flush(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Flushed collection {} successfully!", collectionNames);
@@ -1134,7 +1187,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> flushAsync(@Nonnull List<String> collectionNames) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1143,7 +1196,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.flush(request);
+    response = futureStub().flush(request);
 
     Futures.addCallback(
         response,
@@ -1192,7 +1245,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
   @Override
   public Response compact(String collectionName) {
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return new Response(Response.Status.CLIENT_NOT_CONNECTED);
     }
@@ -1201,7 +1254,7 @@ public class MilvusGrpcClient implements MilvusClient {
     Status response;
 
     try {
-      response = blockingStub.compact(request);
+      response = blockingStub().compact(request);
 
       if (response.getErrorCode() == ErrorCode.SUCCESS) {
         logInfo("Compacted collection `{}` successfully!", collectionName);
@@ -1220,7 +1273,7 @@ public class MilvusGrpcClient implements MilvusClient {
   @Override
   public ListenableFuture<Response> compactAsync(@Nonnull String collectionName) {
 
-    if (!channelIsReadyOrIdle()) {
+    if (!maybeAvailable()) {
       logWarning("You are not connected to Milvus server");
       return Futures.immediateFuture(new Response(Response.Status.CLIENT_NOT_CONNECTED));
     }
@@ -1229,7 +1282,7 @@ public class MilvusGrpcClient implements MilvusClient {
 
     ListenableFuture<Status> response;
 
-    response = futureStub.compact(request);
+    response = futureStub().compact(request);
 
     Futures.addCallback(
         response,
@@ -1322,16 +1375,6 @@ public class MilvusGrpcClient implements MilvusClient {
     return searchResponse;
   }
 
-  private boolean channelIsReadyOrIdle() {
-    if (channel == null) {
-      return false;
-    }
-    ConnectivityState connectivityState = channel.getState(false);
-    return connectivityState == ConnectivityState.READY
-        || connectivityState
-            == ConnectivityState.IDLE; // Since a new RPC would take the channel out of idle mode
-  }
-
   ///////////////////// Log Functions//////////////////////
 
   private void logInfo(String msg, Object... params) {

+ 16 - 0
src/main/java/io/milvus/client/exception/InitializationException.java

@@ -0,0 +1,16 @@
+package io.milvus.client.exception;
+
+public class InitializationException extends MilvusException {
+  private String host;
+  private Throwable cause;
+
+  public InitializationException(String host, Throwable cause) {
+    super(false, cause);
+    this.host = host;
+  }
+
+  public InitializationException(String host, String message) {
+    super(false, message);
+    this.host = host;
+  }
+}

+ 24 - 0
src/main/java/io/milvus/client/exception/MilvusException.java

@@ -0,0 +1,24 @@
+package io.milvus.client.exception;
+
+public class MilvusException extends RuntimeException {
+  private boolean fillInStackTrace;
+
+  MilvusException(boolean fillInStackTrace) {
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, Throwable cause) {
+    super(cause);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  MilvusException(boolean fillInStackTrace, String message) {
+    super(message);
+    this.fillInStackTrace = fillInStackTrace;
+  }
+
+  @Override
+  public synchronized Throwable fillInStackTrace() {
+    return fillInStackTrace ? super.fillInStackTrace() : this;
+  }
+}

+ 22 - 0
src/main/java/io/milvus/client/exception/UnsupportedServerVersion.java

@@ -0,0 +1,22 @@
+package io.milvus.client.exception;
+
+import io.milvus.client.MilvusClient;
+
+public class UnsupportedServerVersion extends MilvusException {
+  private String host;
+  private String expect;
+  private String actual;
+
+  public UnsupportedServerVersion(String host, String expect, String actual) {
+    super(false);
+    this.host = host;
+    this.expect = expect;
+    this.actual = actual;
+  }
+
+  @Override
+  public String getMessage() {
+    return String.format("%s: Milvus client %s is expected to work with Milvus server %s, but the version of the connected server is %s",
+        host, MilvusClient.clientVersion, expect, actual);
+  }
+}

+ 3 - 0
src/main/resources/milvus-client.properties

@@ -0,0 +1,3 @@
+groupId=${project.groupId}
+artifactId=${project.artifactId}
+version=${project.version}

+ 64 - 16
src/test/java/io/milvus/client/MilvusGrpcClientTest.java

@@ -20,8 +20,15 @@
 package io.milvus.client;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import io.milvus.client.exception.InitializationException;
+import io.milvus.client.exception.UnsupportedServerVersion;
 import org.apache.commons.text.RandomStringGenerator;
 import org.json.*;
+import org.junit.jupiter.api.condition.DisabledIfSystemProperty;
+import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
 
 import java.nio.ByteBuffer;
 import java.util.*;
@@ -33,6 +40,32 @@ import java.util.stream.LongStream;
 
 import static org.junit.jupiter.api.Assertions.*;
 
+@Testcontainers
+@EnabledIfSystemProperty(named = "with-containers", matches = "true")
+class ContainerMilvusClientTest extends MilvusClientTest {
+  @Container
+  private static GenericContainer milvusContainer =
+      new GenericContainer("milvusdb/milvus:0.10.1-cpu-d072020-bd02b1")
+          .withExposedPorts(19530);
+
+  @Container
+  private static GenericContainer unsupportedMilvusContainer =
+      new GenericContainer("milvusdb/milvus:0.9.1-cpu-d052920-e04ed5")
+          .withExposedPorts(19530);
+
+  @Override
+  protected ConnectParam.Builder connectParamBuilder() {
+    return connectParamBuilder(milvusContainer);
+  }
+
+  @org.junit.jupiter.api.Test
+  void unsupportedServerVersion() {
+    ConnectParam connectParam = connectParamBuilder(unsupportedMilvusContainer).build();
+    assertThrows(UnsupportedServerVersion.class, () -> new MilvusGrpcClient(connectParam));
+  }
+}
+
+@DisabledIfSystemProperty(named = "with-containers", matches = "true")
 class MilvusClientTest {
 
   private MilvusClient client;
@@ -43,6 +76,18 @@ class MilvusClientTest {
   private int size;
   private int dimension;
 
+  protected ConnectParam.Builder connectParamBuilder() {
+    return connectParamBuilder("localhost", 19530);
+  }
+
+  protected ConnectParam.Builder connectParamBuilder(GenericContainer milvusContainer) {
+    return connectParamBuilder(milvusContainer.getHost(), milvusContainer.getFirstMappedPort());
+  }
+
+  private ConnectParam.Builder connectParamBuilder(String host, int port) {
+    return new ConnectParam.Builder().withHost(host).withPort(port);
+  }
+
   // Helper function that generates random float vectors
   static List<List<Float>> generateFloatVectors(int vectorCount, int dimension) {
     SplittableRandom splittableRandom = new SplittableRandom();
@@ -81,11 +126,8 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.BeforeEach
   void setUp() throws Exception {
-
-    client = new MilvusGrpcClient();
-    ConnectParam connectParam =
-        new ConnectParam.Builder().withHost("localhost").withPort(19530).build();
-    client.connect(connectParam);
+    ConnectParam connectParam = connectParamBuilder().build();
+    client = new MilvusGrpcClient(connectParam);
 
     generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();
     randomCollectionName = generator.generate(10);
@@ -103,18 +145,15 @@ class MilvusClientTest {
   @org.junit.jupiter.api.AfterEach
   void tearDown() throws InterruptedException {
     assertTrue(client.dropCollection(randomCollectionName).ok());
-    client.disconnect();
+    client.close();
   }
 
   @org.junit.jupiter.api.Test
   void idleTest() throws InterruptedException, ConnectFailedException {
-    MilvusClient client = new MilvusGrpcClient();
-    ConnectParam connectParam =
-        new ConnectParam.Builder()
-            .withHost("localhost")
-            .withIdleTimeout(1, TimeUnit.SECONDS)
-            .build();
-    client.connect(connectParam);
+    ConnectParam connectParam = connectParamBuilder()
+        .withIdleTimeout(1, TimeUnit.SECONDS)
+        .build();
+    MilvusClient client = new MilvusGrpcClient(connectParam);
     TimeUnit.SECONDS.sleep(2);
     // A new RPC would take the channel out of idle mode
     assertTrue(client.listCollections().ok());
@@ -155,9 +194,18 @@ class MilvusClientTest {
 
   @org.junit.jupiter.api.Test
   void connectUnreachableHost() {
-    MilvusClient client = new MilvusGrpcClient();
-    ConnectParam connectParam = new ConnectParam.Builder().withHost("250.250.250.250").build();
-    assertThrows(ConnectFailedException.class, () -> client.connect(connectParam));
+    ConnectParam connectParam = connectParamBuilder("250.250.250.250", 19530).build();
+    assertThrows(InitializationException.class, () -> new MilvusGrpcClient(connectParam));
+  }
+
+  @org.junit.jupiter.api.Test
+  void grpcTimeout() {
+    insert();
+    MilvusClient timeoutClient = client.withTimeout(1, TimeUnit.MILLISECONDS);
+    Response response = timeoutClient.createIndex(
+        new Index.Builder(randomCollectionName, IndexType.FLAT)
+            .withParamsInJson("{\"nlist\": 16384}").build());
+    assertEquals(Response.Status.RPC_ERROR, response.getStatus());
   }
 
   @org.junit.jupiter.api.Test