Преглед изворни кода

Adding Exception handling for MilvusClient V2 (#1671)

* Adding Exception handling for validating hostname, port and cert in  MilvusClientV2

Signed-off-by: Divya <DIVYA2@ibm.com>

* Addressing review comments

Signed-off-by: Divya <DIVYA2@ibm.com>

* formatting correction

Signed-off-by: Divya <DIVYA2@ibm.com>

* Addressing review comments

Signed-off-by: Divya <DIVYA2@ibm.com>

* fixing compilation error

Signed-off-by: Divya <DIVYA2@ibm.com>

* Corrected validateCert implementation

Signed-off-by: Divya <DIVYA2@ibm.com>

---------

Signed-off-by: Divya <DIVYA2@ibm.com>
Co-authored-by: Divya <DIVYA2@ibm.com>
Divya пре 2 месеци
родитељ
комит
50958beaf6

+ 16 - 0
sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java

@@ -48,6 +48,7 @@ public class ConnectConfig {
     private String proxyAddress;
     private Boolean secure = false;
     private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
+    private boolean enablePrecheck = false;  // default value is false
 
     private SSLContext sslContext;
     // clientRequestId maintains a map for different threads, each thread can assign a specific id.
@@ -79,6 +80,7 @@ public class ConnectConfig {
         this.idleTimeoutMs = builder.idleTimeoutMs;
         this.sslContext = builder.sslContext;
         this.clientRequestId = builder.clientRequestId;
+        this.enablePrecheck = builder.enablePrecheck;
     }
 
     public static ConnectConfigBuilder builder() {
@@ -162,6 +164,9 @@ public class ConnectConfig {
         return proxyAddress;
     }
 
+    public boolean isEnablePrecheck() {
+        return enablePrecheck;
+    }
     // Setters
     public void setUri(String uri) {
         if (uri == null) {
@@ -234,6 +239,10 @@ public class ConnectConfig {
         this.secure = secure;
     }
 
+    public void setEnablePrecheck(boolean enablePrecheck) {
+        this.enablePrecheck = enablePrecheck;
+    }
+
     public void setIdleTimeoutMs(long idleTimeoutMs) {
         this.idleTimeoutMs = idleTimeoutMs;
     }
@@ -301,6 +310,7 @@ public class ConnectConfig {
                 ", serverName='" + serverName + '\'' +
                 ", proxyAddress='" + proxyAddress + '\'' +
                 ", secure=" + secure +
+                ", enablePrecheck=" + enablePrecheck +
                 ", idleTimeoutMs=" + idleTimeoutMs +
                 ", sslContext=" + sslContext +
                 ", clientRequestId=" + clientRequestId +
@@ -328,6 +338,7 @@ public class ConnectConfig {
         private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
         private SSLContext sslContext;
         private ThreadLocal<String> clientRequestId;
+        private boolean enablePrecheck = false;
 
         public ConnectConfigBuilder uri(String uri) {
             if (uri == null) {
@@ -417,6 +428,11 @@ public class ConnectConfig {
             return this;
         }
 
+        public ConnectConfigBuilder enablePrecheck(boolean enablePrecheck) {
+            this.enablePrecheck = enablePrecheck;
+            return this;
+        }
+
         public ConnectConfigBuilder idleTimeoutMs(long idleTimeoutMs) {
             this.idleTimeoutMs = idleTimeoutMs;
             return this;

+ 135 - 0
sdk-core/src/main/java/io/milvus/v2/client/MilvusClientV2.java

@@ -27,6 +27,8 @@ import io.milvus.grpc.MilvusServiceGrpc;
 import io.milvus.orm.iterator.QueryIterator;
 import io.milvus.orm.iterator.SearchIterator;
 import io.milvus.orm.iterator.SearchIteratorV2;
+import io.milvus.v2.exception.ErrorCode;
+import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.cdc.CDCService;
 import io.milvus.v2.service.cdc.request.UpdateReplicateConfigurationReq;
 import io.milvus.v2.service.cdc.response.UpdateReplicateConfigurationResp;
@@ -69,6 +71,19 @@ import org.slf4j.LoggerFactory;
 
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.net.Socket;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import javax.net.ssl.*;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.security.KeyStore;
+import java.security.PrivateKey;
+import java.security.SecureRandom;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
 
 public class MilvusClientV2 {
     private static final Logger logger = LoggerFactory.getLogger(MilvusClientV2.class);
@@ -124,6 +139,11 @@ public class MilvusClientV2 {
      */
     private void connect(ConnectConfig connectConfig) {
         this.connectConfig = connectConfig;
+        if (connectConfig.isEnablePrecheck()) {
+            validateHostname(connectConfig);
+            validatePort(connectConfig);
+            validateCert(connectConfig);
+        }
         try {
             if (this.channel != null) {
                 // close channel first
@@ -215,6 +235,121 @@ public class MilvusClientV2 {
         return dbName;
     }
 
+    /**
+     * Validates that the hostname can be resolved before attempting connection.
+     * This provides early failure with clear error messages for DNS issues.
+     *
+     * @param connectConfig Connection configuration containing the host to validate
+     * @throws MilvusClientException if hostname cannot be resolved
+     */
+    public void validateHostname(ConnectConfig connectConfig) {
+        String host = connectConfig.getHost();
+        
+        if (StringUtils.isEmpty(host)) {
+            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, 
+                "Hostname cannot be null or empty");
+        }
+        
+        try {
+            // Attempt DNS resolution
+            InetAddress.getByName(host);
+            logger.debug("Successfully resolved hostname: {}", host);
+        } catch (UnknownHostException e) {
+            String message = String.format(
+                "Failed to resolve hostname '%s'. Please verify the hostname is correct and DNS is configured properly.",
+                host
+            );
+            logger.error(message, e);
+            throw new MilvusClientException(ErrorCode.RPC_ERROR, message);
+        }
+    }
+
+    /**
+     * Validates port number and tests connectivity.
+     *
+     * @param connectConfig Connection configuration containing the port to validate
+     * @throws MilvusClientException if port is invalid or unreachable
+     */
+    public void validatePort(ConnectConfig connectConfig) {
+        int port = connectConfig.getPort();
+        String host = connectConfig.getHost();
+        
+        // Check valid range
+        if (port < 1 || port > 65535) {
+            String message = String.format(
+                "Invalid port number '%d'. Port must be between 1 and 65535.",
+                port
+            );
+            logger.error(message);
+            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, message);
+        }
+        
+        // Test if port is reachable
+        try (Socket socket = new Socket()) {
+            socket.connect(new InetSocketAddress(host, port), (int) connectConfig.getConnectTimeoutMs());
+            logger.debug("Successfully validated port: {}", port);
+        } catch (IOException e) {
+            String message = String.format(
+                "Cannot connect to '%s:%d'. Please verify the port number is correct and server is running.",
+                host, port
+            );
+            logger.error(message, e);
+            throw new MilvusClientException(ErrorCode.RPC_ERROR, message);
+        }
+    }
+    
+    /**
+     * Validates SSL connection with certificates.
+     *
+     * @param connectConfig Connection configuration
+     * @throws MilvusClientException if SSL connection fails
+     */
+    public void validateCert(ConnectConfig connectConfig) {
+        if (!connectConfig.isSecure()) {
+            return;
+        }
+        
+        try {
+            SSLContext sslContext = SSLContext.getInstance("TLS");
+            TrustManagerFactory tmf = null;
+            
+            // Load server certificate (CA cert)
+            if (connectConfig.getServerPemPath() != null && !connectConfig.getServerPemPath().isEmpty()) {
+                try (InputStream certStream = new FileInputStream(connectConfig.getServerPemPath())) {
+                    CertificateFactory cf = CertificateFactory.getInstance("X.509");
+                    X509Certificate caCert = (X509Certificate) cf.generateCertificate(certStream);
+                    
+                    KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
+                    trustStore.load(null, null);
+                    trustStore.setCertificateEntry("ca-cert", caCert);
+                    
+                    tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+                    tmf.init(trustStore);
+                }
+            }
+            
+            // Initialize SSLContext with the server certificate
+            sslContext.init(null, tmf != null ? tmf.getTrustManagers() : null, new SecureRandom());
+            
+            // Validate connection
+            SSLSocketFactory socketFactory = sslContext.getSocketFactory();
+            try (SSLSocket socket = (SSLSocket) socketFactory.createSocket()) {
+                socket.connect(new InetSocketAddress(connectConfig.getHost(), connectConfig.getPort()), 
+                            (int) connectConfig.getConnectTimeoutMs());
+                socket.startHandshake();
+                logger.debug("SSL certificate validation passed");
+            }
+            
+        } catch (SSLException e) {
+            throw new MilvusClientException(ErrorCode.RPC_ERROR, 
+                "SSL certificate validation failed: " + e.getMessage() + 
+                ". Please verify your certificates are correct.");
+        } catch (Exception e) {
+            throw new MilvusClientException(ErrorCode.RPC_ERROR, 
+                "Failed to connect with SSL: " + e.getMessage());
+        }
+    }
+
     /////////////////////////////////////////////////////////////////////////////////////////////
     // Database Operations
     /////////////////////////////////////////////////////////////////////////////////////////////