Browse Source

Adding in Proxy setting for connection to milvus. (#1350)

* Adding in Proxy setting for connection to milvus.

Co-authored-by: divyaruhil <divyaruhil999@gmail.com>
Signed-off-by: jeri-jose <jerijose111@gmail.com>

* proxy-setting configuration into reusable method

Co-authored-by: divyaruhil <divyaruhil999@gmail.com>
Signed-off-by: jeri-jose <jerijose111@gmail.com>

---------

Signed-off-by: jeri-jose <jerijose111@gmail.com>
Co-authored-by: divyaruhil <divyaruhil999@gmail.com>
Jeri Jose 3 weeks ago
parent
commit
c57a8df5c1

+ 18 - 2
sdk-core/src/main/java/io/milvus/client/MilvusServiceClient.java

@@ -46,6 +46,9 @@ import io.milvus.param.index.*;
 import io.milvus.param.partition.*;
 import io.milvus.param.resourcegroup.*;
 import io.milvus.param.role.*;
+import io.milvus.v2.utils.ClientUtils;
+import io.grpc.ProxiedSocketAddress;
+import io.grpc.ProxyDetector;
 import lombok.NonNull;
 import org.apache.commons.lang3.StringUtils;
 
@@ -58,6 +61,9 @@ import java.util.concurrent.TimeUnit;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
 import java.time.LocalDateTime;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import io.grpc.HttpConnectProxiedSocketAddress;
 
 public class MilvusServiceClient extends AbstractMilvusGrpcClient {
 
@@ -102,7 +108,6 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                 SslContext sslContext = GrpcSslContexts.forClient()
                         .trustManager(new File(connectParam.getServerPemPath()))
                         .build();
-
                 NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
                         .overrideAuthority(connectParam.getServerName())
                         .sslContext(sslContext)
@@ -112,6 +117,10 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(clientInterceptors);
+                // Add proxy configuration if proxy address is set
+                if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
+                    ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
+                }
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -124,7 +133,6 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .trustManager(new File(connectParam.getCaPemPath()))
                         .keyManager(new File(connectParam.getClientPemPath()), new File(connectParam.getClientKeyPath()))
                         .build();
-
                 NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
                         .sslContext(sslContext)
                         .maxInboundMessageSize(Integer.MAX_VALUE)
@@ -133,6 +141,11 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(clientInterceptors);
+                
+                // Add proxy configuration if proxy address is set
+                if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
+                    ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
+                }     
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -150,6 +163,9 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
                         .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
                         .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(clientInterceptors);
+                if (StringUtils.isNotEmpty(connectParam.getProxyAddress())) {
+                    ClientUtils.configureProxy(builder, connectParam.getProxyAddress());
+                }
                 if(connectParam.isSecure()){
                     builder.useTransportSecurity();
                 }

+ 16 - 1
sdk-core/src/main/java/io/milvus/param/ConnectParam.java

@@ -59,6 +59,7 @@ public class ConnectParam {
     private final String serverName;
     private final String userName;
     private final ThreadLocal<String> clientRequestId;
+    private final String proxyAddress;
 
     protected ConnectParam(@NonNull Builder builder) {
         this.host = builder.host;
@@ -81,6 +82,7 @@ public class ConnectParam {
         this.serverName = builder.serverName;
         this.userName = builder.userName;
         this.clientRequestId = builder.clientRequestId;
+        this.proxyAddress = builder.proxyAddress;
     }
 
     public static Builder newBuilder() {
@@ -120,6 +122,8 @@ public class ConnectParam {
 
         //used to set client_request_id in the grpc header uniquely for every request
         private ThreadLocal<String> clientRequestId;
+        
+        private String proxyAddress;
 
         protected Builder() {
         }
@@ -359,6 +363,17 @@ public class ConnectParam {
             this.clientRequestId = clientRequestId;
             return this;
         }
+        
+        /**
+         * Sets the proxy address for connections through a proxy server.
+         * 
+         * @param proxyAddress proxy server address in format "host:port"
+         * @return <code>Builder</code>
+         */
+        public Builder withProxyAddress(String proxyAddress) {
+            this.proxyAddress = proxyAddress;
+            return this;
+        }
 
         /**
          * Verifies parameters and creates a new {@link ConnectParam} instance.
@@ -418,4 +433,4 @@ public class ConnectParam {
             }
         }
     }
-}
+}

+ 5 - 0
sdk-core/src/main/java/io/milvus/v2/client/ConnectConfig.java

@@ -56,6 +56,7 @@ public class ConnectConfig {
     private String caPemPath;
     private String serverPemPath;
     private String serverName;
+    private String proxyAddress;
     @Builder.Default
     private Boolean secure = false;
     @Builder.Default
@@ -97,4 +98,8 @@ public class ConnectConfig {
         }
         return secure;
     }
+
+    public String  getProxyAddress(){
+        return proxyAddress;
+    }
 }

+ 48 - 1
sdk-core/src/main/java/io/milvus/v2/utils/ClientUtils.java

@@ -33,6 +33,9 @@ import io.grpc.stub.MetadataUtils;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.grpc.*;
 import io.milvus.v2.client.ConnectConfig;
+import io.grpc.HttpConnectProxiedSocketAddress;
+import io.grpc.ProxiedSocketAddress;
+import io.grpc.ProxyDetector;
 import org.apache.commons.lang3.StringUtils;
 import org.jetbrains.annotations.NotNull;
 import org.slf4j.Logger;
@@ -46,6 +49,8 @@ import java.nio.charset.StandardCharsets;
 import java.time.LocalDateTime;
 import java.util.Base64;
 import java.util.concurrent.TimeUnit;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
 
 public class ClientUtils {
     Logger logger = LoggerFactory.getLogger(ClientUtils.class);
@@ -73,6 +78,11 @@ public class ClientUtils {
                         .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
                         .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                
+                if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
+                    configureProxy(builder, connectConfig.getProxyAddress());
+                }
+                
                 if(connectConfig.isSecure()) {
                     builder.useTransportSecurity();
                 }
@@ -95,6 +105,11 @@ public class ClientUtils {
                         .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
                         .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+
+                if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
+                    configureProxy(builder, connectConfig.getProxyAddress());
+                }
+
                 if(connectConfig.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -102,7 +117,7 @@ public class ClientUtils {
             } else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath())
                     && StringUtils.isNotEmpty(connectConfig.getClientKeyPath())
                     && StringUtils.isNotEmpty(connectConfig.getCaPemPath())) {
-                // tow-way tls
+                // two-way tls
                 SslContext sslContext = GrpcSslContexts.forClient()
                         .trustManager(new File(connectConfig.getCaPemPath()))
                         .keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath()))
@@ -116,6 +131,11 @@ public class ClientUtils {
                         .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
                         .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                
+                if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
+                    configureProxy(builder, connectConfig.getProxyAddress());
+                }
+                
                 if (connectConfig.getSecure()) {
                     builder.useTransportSecurity();
                 }
@@ -133,6 +153,9 @@ public class ClientUtils {
                         .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
                         .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
                         .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                if (StringUtils.isNotEmpty(connectConfig.getProxyAddress())) {
+                    configureProxy(builder, connectConfig.getProxyAddress());
+                }
                 if(connectConfig.isSecure()){
                     builder.useTransportSecurity();
                 }
@@ -145,6 +168,30 @@ public class ClientUtils {
         return channel;
     }
 
+    /**
+     * Configures the proxy settings for a NettyChannelBuilder if proxy address is specified
+     * 
+     * @param builder NettyChannelBuilder to configure
+     * @param connectConfig Connection configuration containing proxy settings
+     */
+    public static void configureProxy(ManagedChannelBuilder builder, String proxyAddress) {
+        String[] hostPort = proxyAddress.split(":");
+        if (hostPort.length == 2) {
+            String proxyHost = hostPort[0];
+            int proxyPort = Integer.parseInt(hostPort[1]);
+
+            builder.proxyDetector(new ProxyDetector() {
+                @Override
+                public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) {
+                    return HttpConnectProxiedSocketAddress.newBuilder()
+                            .setProxyAddress(new InetSocketAddress(proxyHost, proxyPort))
+                            .setTargetAddress((InetSocketAddress) targetServerAddress)
+                            .build();
+                }
+            });
+        }
+    }
+
     private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) {
         ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE,
                 ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT);