Browse Source

Adding support for milvus vector db processor conversion from java SSLContext to netty SslContext (#1068)

Signed-off-by: Noah Cover <ncover@datavolo.io>
Noah 9 months ago
parent
commit
a72b19a2bc

+ 3 - 0
src/main/java/io/milvus/v2/client/ConnectConfig.java

@@ -24,6 +24,7 @@ import lombok.Data;
 import lombok.NonNull;
 import lombok.NonNull;
 import lombok.experimental.SuperBuilder;
 import lombok.experimental.SuperBuilder;
 
 
+import javax.net.ssl.SSLContext;
 import java.net.URI;
 import java.net.URI;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 
 
@@ -57,6 +58,8 @@ public class ConnectConfig {
     @Builder.Default
     @Builder.Default
     private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
     private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
 
 
+    private SSLContext sslContext;
+
     public String getHost() {
     public String getHost() {
         URI uri = URI.create(this.uri);
         URI uri = URI.create(this.uri);
         return uri.getHost();
         return uri.getHost();

+ 31 - 1
src/main/java/io/milvus/v2/utils/ClientUtils.java

@@ -24,12 +24,17 @@ import io.grpc.ManagedChannelBuilder;
 import io.grpc.Metadata;
 import io.grpc.Metadata;
 import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
 import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
 import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
 import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
+import io.grpc.netty.shaded.io.netty.handler.ssl.ApplicationProtocolConfig;
+import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
+import io.grpc.netty.shaded.io.netty.handler.ssl.IdentityCipherSuiteFilter;
+import io.grpc.netty.shaded.io.netty.handler.ssl.JdkSslContext;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
 import io.grpc.stub.MetadataUtils;
 import io.grpc.stub.MetadataUtils;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.grpc.*;
 import io.milvus.grpc.*;
 import io.milvus.v2.client.ConnectConfig;
 import io.milvus.v2.client.ConnectConfig;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.StringUtils;
+import org.jetbrains.annotations.NotNull;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
@@ -57,7 +62,25 @@ public class ClientUtils {
         }
         }
 
 
         try {
         try {
-            if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) {
+            if (connectConfig.getSslContext() != null) {
+                // sslContext from connect config
+                NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort())
+                        .overrideAuthority(connectConfig.getServerName())
+                        .sslContext(convertJavaSslContextToNetty(connectConfig))
+                        .maxInboundMessageSize(Integer.MAX_VALUE)
+                        .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
+                        .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                if(connectConfig.isSecure()) {
+                    builder.useTransportSecurity();
+                }
+                if (StringUtils.isNotEmpty(connectConfig.getServerName())) {
+                    builder.overrideAuthority(connectConfig.getServerName());
+                }
+                channel = builder.build();
+            } else if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) {
                 // one-way tls
                 // one-way tls
                 SslContext sslContext = GrpcSslContexts.forClient()
                 SslContext sslContext = GrpcSslContexts.forClient()
                         .trustManager(new File(connectConfig.getServerPemPath()))
                         .trustManager(new File(connectConfig.getServerPemPath()))
@@ -122,6 +145,13 @@ public class ClientUtils {
         return channel;
         return channel;
     }
     }
 
 
+    private static JdkSslContext convertJavaSslContextToNetty(ConnectConfig connectConfig) {
+        ApplicationProtocolConfig applicationProtocolConfig = new ApplicationProtocolConfig(ApplicationProtocolConfig.Protocol.NONE,
+                ApplicationProtocolConfig.SelectorFailureBehavior.FATAL_ALERT, ApplicationProtocolConfig.SelectedListenerFailureBehavior.FATAL_ALERT);
+        return new JdkSslContext(connectConfig.getSslContext(), true, null,
+                IdentityCipherSuiteFilter.INSTANCE, applicationProtocolConfig, ClientAuth.NONE, null, false);
+    }
+
     public void checkDatabaseExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName) {
     public void checkDatabaseExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName) {
         String title = String.format("Check database %s exist", dbName);
         String title = String.format("Check database %s exist", dbName);
         ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.newBuilder().build();
         ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.newBuilder().build();