소스 검색

Connection initialize (#861)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 년 전
부모
커밋
60a1a5eb8b

+ 7 - 1
examples/pom.xml

@@ -83,6 +83,12 @@
             <artifactId>tensorflow-core-platform</artifactId>
             <version>0.5.0</version>
         </dependency>
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <version>1.18.22</version>
+            <scope>provided</scope>
+        </dependency>
     </dependencies>
 
-</project>
+</project>

+ 70 - 1
src/main/java/io/milvus/client/MilvusServiceClient.java

@@ -53,6 +53,9 @@ import java.io.File;
 import java.io.IOException;
 import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.time.LocalDateTime;
 
 public class MilvusServiceClient extends AbstractMilvusGrpcClient {
 
@@ -132,12 +135,25 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                 channel = builder.build();
             }
         } catch (IOException e) {
-            logError("Failed to open credentials file, error:{}\n", e.getMessage());
+            String msg = "Failed to open credentials file. Error: " + e.getMessage();
+            logError(msg);
+            throw new RuntimeException(msg);
         }
 
         assert channel != null;
         blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
         futureStub = MilvusServiceGrpc.newFutureStub(channel);
+
+        // calls a RPC Connect() to the remote server, and sends the client info to the server
+        // so that the server knows which client is interacting, especially for accesses log.
+        this.timeoutMs = connectParam.getConnectTimeoutMs(); // set this value to connectTimeoutMs to control the retry()
+        R<ConnectResponse> resp = this.retry(()->connect(connectParam));
+        if (resp.getStatus() != R.Status.Success.getCode()) {
+            String msg = "Failed to initialize connection. Error: " + resp.getMessage();
+            logError(msg);
+            throw new RuntimeException(msg);
+        }
+        this.timeoutMs = 0; // reset the timeout value to default
     }
 
     protected MilvusServiceClient(MilvusServiceClient src) {
@@ -345,6 +361,59 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
         return R.failed(new RuntimeException(msg));
     }
 
+    /**
+     * This method is internal used, it calls a RPC Connect() to the remote server,
+     * and sends the client info to the server so that the server knows which client is interacting,
+     * especially for accesses log.
+     *
+     * The info includes:
+     * 1. username(if Authentication is enabled)
+     * 2. the client computer's name
+     * 3. sdk language type and version
+     * 4. the client's local time
+     */
+    private R<ConnectResponse> connect(@NonNull ConnectParam connectParam) {
+        ClientInfo info = ClientInfo.newBuilder()
+                .setSdkType("Java")
+                .setSdkVersion(getSDKVersion())
+                .setUser(connectParam.getUserName())
+                .setHost(getHostName())
+                .setLocalTime(getLocalTimeStr())
+                .build();
+        ConnectRequest req = ConnectRequest.newBuilder().setClientInfo(info).build();
+        ConnectResponse resp = this.blockingStub.withWaitForReady()
+                .withDeadlineAfter(connectParam.getConnectTimeoutMs(), TimeUnit.MILLISECONDS)
+                .connect(req);
+        if (resp.getStatus().getCode() != 0 || !resp.getStatus().getErrorCode().equals(ErrorCode.Success)) {
+            throw new RuntimeException("Failed to initialize connection. Error: " + resp.getStatus().getReason());
+        }
+        return R.success(resp);
+    }
+
+    private String getHostName() {
+        try {
+            InetAddress address = InetAddress.getLocalHost();
+            return address.getHostName();
+        } catch (UnknownHostException e) {
+            logWarning("Failed to get host name! Exception:{}", e);
+            return "Unknown";
+        }
+    }
+
+    private String getLocalTimeStr() {
+        LocalDateTime now = LocalDateTime.now();
+        return now.toString();
+    }
+
+    private String getSDKVersion() {
+        Package pkg = MilvusServiceClient.class.getPackage();
+        String ver = pkg.getImplementationVersion();
+        if (ver == null) {
+            return "";
+        }
+        return ver;
+    }
+
     @Override
     public void setLogLevel(LogLevel level) {
         logLevel = level;

+ 3 - 2
src/main/java/io/milvus/common/resourcegroup/ResourceGroupConfig.java

@@ -2,6 +2,7 @@ package io.milvus.common.resourcegroup;
 
 import java.util.stream.Collectors;
 import java.util.List;
+import java.util.ArrayList;
 import lombok.NonNull;
 import lombok.Getter;
 
@@ -17,13 +18,13 @@ public class ResourceGroupConfig {
         this.limits = builder.limits;
 
         if (null == builder.from) {
-            this.from = List.of();
+            this.from = new ArrayList<>();
         } else {
             this.from = builder.from;
         }
 
         if (null == builder.to) {
-            this.to = List.of();
+            this.to = new ArrayList<>();
         } else {
             this.to = builder.to;
         }

+ 8 - 0
src/main/java/io/milvus/param/ConnectParam.java

@@ -57,6 +57,7 @@ public class ConnectParam {
     private final String caPemPath;
     private final String serverPemPath;
     private final String serverName;
+    private final String userName;
 
     protected ConnectParam(@NonNull Builder builder) {
         this.host = builder.host;
@@ -77,6 +78,7 @@ public class ConnectParam {
         this.caPemPath = builder.caPemPath;
         this.serverPemPath = builder.serverPemPath;
         this.serverName = builder.serverName;
+        this.userName = builder.userName;
     }
 
     public static Builder newBuilder() {
@@ -109,6 +111,11 @@ public class ConnectParam {
         private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
         private String authorization = Base64.getEncoder().encodeToString("root:milvus".getBytes(StandardCharsets.UTF_8));
 
+        // username/password is encoded into authorization, this member is to keep the origin username for MilvusServiceClient.connect()
+        // The MilvusServiceClient.connect() is to send the client info to the server so that the server knows which client is interacting
+        // If the username is unknown, send it as an empty string.
+        private String userName = "";
+
         protected Builder() {
         }
 
@@ -262,6 +269,7 @@ public class ConnectParam {
          */
         public Builder withAuthorization(String username, String password) {
             this.authorization = Base64.getEncoder().encodeToString(String.format("%s:%s", username, password).getBytes(StandardCharsets.UTF_8));
+            this.userName = username;
             return this;
         }
 

+ 30 - 0
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -279,6 +279,36 @@ class MilvusServiceClientTest {
         );
     }
 
+    @Test
+    void testConnect() {
+        ConnectParam connectParam = ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(testPort)
+                .withConnectTimeout(1000, TimeUnit.MILLISECONDS)
+                .build();
+        RetryParam retryParam = RetryParam.newBuilder()
+                .withMaxRetryTimes(2)
+                .build();
+
+        Exception e = assertThrows(RuntimeException.class, () -> {
+            MilvusClient client = new MilvusServiceClient(connectParam).withRetry(retryParam);
+        });
+        assertTrue(e.getMessage().contains("DEADLINE_EXCEEDED"));
+
+        MockMilvusServer server = startServer();
+        String dbName = "base";
+        String reason = "database not found[database=" + dbName + "]";
+        mockServerImpl.setConnectResponse(ConnectResponse.newBuilder()
+                .setStatus(Status.newBuilder().setCode(800).setReason(reason).build()).build());
+
+        e = assertThrows(RuntimeException.class, () -> {
+            MilvusClient client = new MilvusServiceClient(connectParam).withRetry(retryParam);
+        });
+        assertTrue(e.getMessage().contains(reason));
+
+        server.stop();
+    }
+
     @Test
     void createCollectionParam() {
         // test throw exception with illegal input for FieldType

+ 14 - 0
src/test/java/io/milvus/server/MockMilvusServerImpl.java

@@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory;
 
 public class MockMilvusServerImpl extends MilvusServiceGrpc.MilvusServiceImplBase {
     private static final Logger logger = LoggerFactory.getLogger(MockMilvusServerImpl.class);
+    private io.milvus.grpc.ConnectResponse respConnect;
     private io.milvus.grpc.Status respCreateCollection;
     private io.milvus.grpc.DescribeCollectionResponse respDescribeCollection;
     private io.milvus.grpc.Status respDropCollection;
@@ -82,6 +83,19 @@ public class MockMilvusServerImpl extends MilvusServiceGrpc.MilvusServiceImplBas
     public MockMilvusServerImpl() {
     }
 
+    @Override
+    public void connect(io.milvus.grpc.ConnectRequest request,
+                        io.grpc.stub.StreamObserver<io.milvus.grpc.ConnectResponse> responseObserver) {
+        logger.info("MockServer receive connect() call");
+
+        responseObserver.onNext(respConnect);
+        responseObserver.onCompleted();
+    }
+
+    public void setConnectResponse(io.milvus.grpc.ConnectResponse resp) {
+        respConnect = resp;
+    }
+
     @Override
     public void createCollection(io.milvus.grpc.CreateCollectionRequest request,
                                  io.grpc.stub.StreamObserver<io.milvus.grpc.Status> responseObserver) {