瀏覽代碼

Support TLS (#559)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 年之前
父節點
當前提交
65ae173dd7

+ 5 - 0
.gitignore

@@ -29,3 +29,8 @@ hs_err_pid*
 target/
 volumes/
 *.iml
+
+# Example files
+examples/main/java/io/milvus/tls/*
+!examples/main/java/io/milvus/tls/gen.sh
+!examples/main/java/io/milvus/tls/openssl.cnf

+ 19 - 0
examples/main/java/io/milvus/RBACExample.java

@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 package io.milvus;
 
 import io.milvus.client.MilvusServiceClient;

+ 106 - 0
examples/main/java/io/milvus/TLSExample.java

@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus;
+
+import com.alibaba.fastjson.JSONObject;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.grpc.*;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.*;
+import io.milvus.param.index.*;
+import io.milvus.response.*;
+import java.util.*;
+
+
+// Note: read the following description before running this example
+// 1. cmd into the "tls" folder, generate certificate by the following commands.
+// (more details read the https://milvus.io/docs/tls.md)
+//   chmod +x gen.sh
+//   ./gen.sh
+//
+// 2. Configure the file paths of server.pem, server.key, and ca.pem for the server in config/milvus.yaml.
+//    Set tlsMode to 1 for one-way authentication. Set tlsMode to 2 for two-way authentication.
+// (read the doc to know how to config milvus: https://milvus.io/docs/configure-docker.md)
+//    tls:
+//        serverPemPath: [path_to_tls]/tls/server.pem
+//        serverKeyPath: [path_to_tls]/tls/server.key
+//        caPemPath: [path_to_tls]/tls/ca.pem
+//
+//    common:
+//        security:
+//        tlsMode: 2
+//
+// 3. Start milvus server
+// 4. Run this example.
+//    Connect server by oneWayAuth() if the server tlsMode=1, connect server by twoWayAuth() if the server tlsMode=2.
+//
+public class TLSExample {
+
+    private static void oneWayAuth() {
+        String path = ClassLoader.getSystemResource("").getPath();
+        ConnectParam connectParam = ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .withServerName("localhost")
+                .withServerPemPath(path + "/tls/server.pem")
+                .build();
+        MilvusServiceClient milvusClient = new MilvusServiceClient(connectParam);
+
+        R<CheckHealthResponse> health = milvusClient.checkHealth();
+        if (health.getStatus() != R.Status.Success.getCode()) {
+            throw new RuntimeException(health.getMessage());
+        } else {
+            System.out.println(health);
+        }
+    }
+
+    private static void twoWayAuth() {
+        String path = ClassLoader.getSystemResource("").getPath();
+        ConnectParam connectParam = ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .withServerName("localhost")
+                .withCaPemPath(path + "/tls/ca.pem")
+                .withClientKeyPath(path + "/tls/client.key")
+                .withClientPemPath(path + "/tls/client.pem")
+                .build();
+        MilvusServiceClient milvusClient = new MilvusServiceClient(connectParam);
+
+        R<CheckHealthResponse> health = milvusClient.checkHealth();
+        if (health.getStatus() != R.Status.Success.getCode()) {
+            throw new RuntimeException(health.getMessage());
+        } else {
+            System.out.println(health);
+        }
+    }
+
+    // tlsMode=1, set oneWay=true
+    // tlsMode=2, set oneWay=false
+    private static final boolean oneWay = false;
+
+    public static void main(String[] args) {
+        if (oneWay) {
+            oneWayAuth();
+        } else {
+            twoWayAuth();
+        }
+    }
+}

+ 24 - 0
examples/main/java/io/milvus/tls/gen.sh

@@ -0,0 +1,24 @@
+Country="CN"
+State="Shanghai"
+Location="Shanghai"
+Organization="milvus"
+Organizational="milvus"
+CommonName="localhost"
+
+echo "generate ca.key"
+openssl genrsa -out ca.key 2048
+
+echo "generate ca.pem"
+openssl req -new -x509 -key ca.key -out ca.pem -days 3650 -subj "/C=$Country/ST=$State/L=$Location/O=$Organization/OU=$Organizational/CN=$CommonName"
+
+echo "generate server SAN certificate"
+openssl genpkey -algorithm RSA -out server.key
+openssl req -new -nodes -key server.key -out server.csr -days 3650 -subj "/C=$Country/O=$Organization/OU=$Organizational/CN=$CommonName" -config ./openssl.cnf -extensions v3_req
+openssl x509 -req -days 3650 -in server.csr -out server.pem -CA ca.pem -CAkey ca.key -CAcreateserial -extfile ./openssl.cnf -extensions v3_req
+
+echo "generate client SAN certificate"
+openssl genpkey -algorithm RSA -out client.key
+openssl req -new -nodes -key client.key -out client.csr -days 3650 -subj "/C=$Country/O=$Organization/OU=$Organizational/CN=$CommonName" -config ./openssl.cnf -extensions v3_req
+openssl x509 -req -days 3650 -in client.csr -out client.pem -CA ca.pem -CAkey ca.key -CAcreateserial -extfile ./openssl.cnf -extensions v3_req
+
+

+ 213 - 0
examples/main/java/io/milvus/tls/openssl.cnf

@@ -0,0 +1,213 @@
+
+HOME			= .
+RANDFILE		= $ENV::HOME/.rnd
+
+oid_section		= new_oids
+
+
+[ new_oids ]
+
+
+tsa_policy1 = 1.2.3.4.1
+tsa_policy2 = 1.2.3.4.5.6
+tsa_policy3 = 1.2.3.4.5.7
+
+[ ca ]
+default_ca	= CA_default		# The default ca section
+
+[ CA_default ]
+
+dir		= ./demoCA		# Where everything is kept
+certs		= $dir/certs		# Where the issued certs are kept
+crl_dir		= $dir/crl		# Where the issued crl are kept
+database	= $dir/index.txt	# database index file.
+					# several ctificates with same subject.
+new_certs_dir	= $dir/newcerts		# default place for new certs.
+
+certificate	= $dir/cacert.pem 	# The CA certificate
+serial		= $dir/serial 		# The current serial number
+crlnumber	= $dir/crlnumber	# the current crl number
+					# must be commented out to leave a V1 CRL
+crl		= $dir/crl.pem 		# The current CRL
+private_key	= $dir/private/cakey.pem# The private key
+RANDFILE	= $dir/private/.rand	# private random number file
+
+x509_extensions	= usr_cert		# The extentions to add to the cert
+
+name_opt 	= ca_default		# Subject Name options
+cert_opt 	= ca_default		# Certificate field options
+
+copy_extensions = copy
+
+
+default_days	= 365			# how long to certify for
+default_crl_days= 30			# how long before next CRL
+default_md	= default		# use public key default MD
+preserve	= no			# keep passed DN ordering
+
+policy		= policy_match
+
+[ policy_match ]
+countryName		= match
+stateOrProvinceName	= match
+organizationName	= match
+organizationalUnitName	= optional
+commonName		= supplied
+emailAddress		= optional
+
+[ policy_anything ]
+countryName		= optional
+stateOrProvinceName	= optional
+localityName		= optional
+organizationName	= optional
+organizationalUnitName	= optional
+commonName		= supplied
+emailAddress		= optional
+
+[ req ]
+default_bits		= 2048
+default_keyfile 	= privkey.pem
+distinguished_name	= req_distinguished_name
+attributes		= req_attributes
+x509_extensions	= v3_ca	# The extentions to add to the self signed cert
+
+
+string_mask = utf8only
+
+req_extensions = v3_req # The extensions to add to a certificate request
+
+[ req_distinguished_name ]
+countryName			= Country Name (2 letter code)
+countryName_default		= AU
+countryName_min			= 2
+countryName_max			= 2
+
+stateOrProvinceName		= State or Province Name (full name)
+stateOrProvinceName_default	= Some-State
+
+localityName			= Locality Name (eg, city)
+
+0.organizationName		= Organization Name (eg, company)
+0.organizationName_default	= Internet Widgits Pty Ltd
+
+
+organizationalUnitName		= Organizational Unit Name (eg, section)
+
+commonName			= Common Name (e.g. server FQDN or YOUR name)
+commonName_max			= 64
+
+emailAddress			= Email Address
+emailAddress_max		= 64
+
+
+[ req_attributes ]
+challengePassword		= A challenge password
+challengePassword_min		= 4
+challengePassword_max		= 20
+
+unstructuredName		= An optional company name
+
+[ usr_cert ]
+
+
+
+basicConstraints=CA:FALSE
+
+
+
+
+
+
+
+nsComment			= "OpenSSL Generated Certificate"
+
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+
+
+
+
+[ v3_req ]
+
+
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+subjectAltName = @alt_names
+
+[ alt_names ]
+DNS.1 = localhost
+DNS.2 = *.ronething.cn
+DNS.3 = *.ronething.com
+
+[ v3_ca ]
+
+
+
+
+
+subjectKeyIdentifier=hash
+
+authorityKeyIdentifier=keyid:always,issuer
+
+basicConstraints = CA:true
+
+
+
+
+
+[ crl_ext ]
+
+
+authorityKeyIdentifier=keyid:always
+
+[ proxy_cert_ext ]
+
+
+basicConstraints=CA:FALSE
+
+
+
+
+
+
+
+nsComment			= "OpenSSL Generated Certificate"
+
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+
+
+
+proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo
+
+[ tsa ]
+
+default_tsa = tsa_config1	# the default TSA section
+
+[ tsa_config1 ]
+
+dir		= ./demoCA		# TSA root directory
+serial		= $dir/tsaserial	# The current serial number (mandatory)
+crypto_device	= builtin		# OpenSSL engine to use for signing
+signer_cert	= $dir/tsacert.pem 	# The TSA signing certificate
+					# (optional)
+certs		= $dir/cacert.pem	# Certificate chain to include in reply
+					# (optional)
+signer_key	= $dir/private/tsakey.pem # The TSA private key (optional)
+
+default_policy	= tsa_policy1		# Policy if request did not specify it
+					# (optional)
+other_policies	= tsa_policy2, tsa_policy3	# acceptable policies (optional)
+digests		= md5, sha1		# Acceptable message digests (mandatory)
+accuracy	= secs:1, millisecs:500, microsecs:100	# (optional)
+clock_precision_digits  = 0	# number of digits after dot. (optional)
+ordering		= yes	# Is ordering defined for timestamps?
+				# (optional, default: no)
+tsa_name		= yes	# Must the TSA name be included in the reply?
+				# (optional, default: no)
+ess_cert_id_chain	= no	# Must the ESS cert id chain be included?
+				# (optional, default: no)
+

+ 1 - 1
examples/pom.xml

@@ -64,7 +64,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>2.2.8</version>
+            <version>2.2.9</version>
         </dependency>
         <dependency>
             <groupId>com.google.code.gson</groupId>

+ 1 - 2
pom.xml

@@ -111,9 +111,8 @@
     <dependencies>
         <dependency>
             <groupId>io.grpc</groupId>
-            <artifactId>grpc-netty-shaded</artifactId>
+            <artifactId>grpc-netty</artifactId>
             <version>${grpc.version}</version>
-            <scope>runtime</scope>
         </dependency>
         <dependency>
             <groupId>io.grpc</groupId>

+ 6 - 8
src/main/java/io/milvus/client/AbstractMilvusGrpcClient.java

@@ -25,7 +25,6 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import io.grpc.StatusRuntimeException;
-import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.JacksonUtils;
 import io.milvus.common.utils.VectorUtils;
 import io.milvus.exception.*;
@@ -325,7 +324,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
 
     private <T> R<T> failedStatus(String requestName, io.milvus.grpc.Status status) {
         String reason = status.getReason();
-        if (reason == null || reason.isEmpty()) {
+        if (StringUtils.isEmpty(reason)) {
             reason = "error code: " + status.getErrorCode().toString();
         }
         logError(requestName + " failed:\n{}", reason);
@@ -481,10 +480,8 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
                     .setDescription(requestParam.getDescription())
                     .setEnableDynamicField(requestParam.isEnableDynamicField());
 
-            long fieldID = 0;
             for (FieldType fieldType : requestParam.getFieldTypes()) {
                 collectionSchemaBuilder.addFields(ParamUtils.ConvertField(fieldType));
-                fieldID++;
             }
 
             // Construct CreateCollectionRequest
@@ -2982,6 +2979,7 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
     }
 
     @Override
+    @SuppressWarnings("unchecked")
     public R<SearchResponse> search(SearchSimpleParam requestParam) {
         if (!clientIsReady()) {
             return R.failed(new ClientNotConnectedException("Client rpc channel is not ready"));
@@ -3041,25 +3039,25 @@ public abstract class AbstractMilvusGrpcClient implements MilvusClient {
     }
 
     ///////////////////// Log Functions//////////////////////
-    private void logDebug(String msg, Object... params) {
+    protected void logDebug(String msg, Object... params) {
         if (logLevel.ordinal() <= LogLevel.Debug.ordinal()) {
             logger.debug(msg, params);
         }
     }
 
-    private void logInfo(String msg, Object... params) {
+    protected void logInfo(String msg, Object... params) {
         if (logLevel.ordinal() <= LogLevel.Info.ordinal()) {
             logger.info(msg, params);
         }
     }
 
-    private void logWarning(String msg, Object... params) {
+    protected void logWarning(String msg, Object... params) {
         if (logLevel.ordinal() <= LogLevel.Warning.ordinal()) {
             logger.warn(msg, params);
         }
     }
 
-    private void logError(String msg, Object... params) {
+    protected void logError(String msg, Object... params) {
         if (logLevel.ordinal() <= LogLevel.Error.ordinal()) {
             logger.error(msg, params);
         }

+ 1 - 1
src/main/java/io/milvus/client/MilvusMultiServiceClient.java

@@ -94,7 +94,7 @@ public class MilvusMultiServiceClient implements MilvusClient {
                 .withKeepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS)
                 .withKeepAliveTimeout(keepAliveTimeoutMs, TimeUnit.MILLISECONDS)
                 .keepAliveWithoutCalls(keepAliveWithoutCalls)
-                .secure(secure)
+                .withSecure(secure)
                 .withIdleTimeout(idleTimeoutMs, TimeUnit.MILLISECONDS)
                 .withAuthorization(multiConnectParam.getAuthorization())
                 .build();

+ 69 - 15
src/main/java/io/milvus/client/MilvusServiceClient.java

@@ -25,16 +25,20 @@ import io.milvus.grpc.MilvusServiceGrpc;
 import io.milvus.param.ConnectParam;
 
 import io.milvus.param.LogLevel;
-import io.milvus.param.R;
-import io.milvus.param.RpcStatus;
 import lombok.NonNull;
 import org.apache.commons.lang3.StringUtils;
 
+import java.io.IOException;
 import java.util.concurrent.TimeUnit;
+import java.io.File;
+
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.handler.ssl.SslContext;
 
 public class MilvusServiceClient extends AbstractMilvusGrpcClient {
 
-    private final ManagedChannel channel;
+    private ManagedChannel channel;
     private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
     private final MilvusServiceGrpc.MilvusServiceFutureStub futureStub;
     private final long rpcDeadlineMs;
@@ -48,20 +52,70 @@ public class MilvusServiceClient extends AbstractMilvusGrpcClient {
             metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectParam.getDatabaseName());
         }
 
-        ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
-                .usePlaintext()
-                .maxInboundMessageSize(Integer.MAX_VALUE)
-                .keepAliveTime(connectParam.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
-                .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
-                .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
-                .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
-                .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
-
-        if(connectParam.isSecure()){
-            builder.useTransportSecurity();
+        try {
+            if (StringUtils.isNotEmpty(connectParam.getServerPemPath())) {
+                // one-way tls
+                SslContext sslContext = GrpcSslContexts.forClient()
+                        .trustManager(new File(connectParam.getServerPemPath()))
+                        .build();
+
+                NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+                        .overrideAuthority(connectParam.getServerName())
+                        .sslContext(sslContext)
+                        .maxInboundMessageSize(Integer.MAX_VALUE)
+                        .keepAliveTime(connectParam.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+                        .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                if(connectParam.isSecure()){
+                    builder.useTransportSecurity();
+                }
+                channel = builder.build();
+            } else if (StringUtils.isNotEmpty(connectParam.getClientPemPath())
+                    && StringUtils.isNotEmpty(connectParam.getClientKeyPath())
+                    && StringUtils.isNotEmpty(connectParam.getCaPemPath())) {
+                // tow-way tls
+                SslContext sslContext = GrpcSslContexts.forClient()
+                        .trustManager(new File(connectParam.getCaPemPath()))
+                        .keyManager(new File(connectParam.getClientPemPath()), new File(connectParam.getClientKeyPath()))
+                        .build();
+
+                NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+                        .sslContext(sslContext)
+                        .maxInboundMessageSize(Integer.MAX_VALUE)
+                        .keepAliveTime(connectParam.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+                        .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                if(connectParam.isSecure()){
+                    builder.useTransportSecurity();
+                }
+                if (StringUtils.isNotEmpty(connectParam.getServerName())) {
+                    builder.overrideAuthority(connectParam.getServerName());
+                }
+                channel = builder.build();
+            } else {
+                // no tls
+                ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress(connectParam.getHost(), connectParam.getPort())
+                        .usePlaintext()
+                        .maxInboundMessageSize(Integer.MAX_VALUE)
+                        .keepAliveTime(connectParam.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveTimeout(connectParam.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .keepAliveWithoutCalls(connectParam.isKeepAliveWithoutCalls())
+                        .idleTimeout(connectParam.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
+                        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
+                if(connectParam.isSecure()){
+                    builder.useTransportSecurity();
+                }
+                channel = builder.build();
+            }
+        } catch (IOException e) {
+            logError("Failed to open credentials file, error:{}\n", e.getMessage());
         }
-        channel = builder.build();
 
+        assert channel != null;
         blockingStub = MilvusServiceGrpc.newBlockingStub(channel);
         futureStub = MilvusServiceGrpc.newFutureStub(channel);
     }

+ 90 - 64
src/main/java/io/milvus/param/ConnectParam.java

@@ -20,7 +20,9 @@
 package io.milvus.param;
 
 import io.milvus.exception.ParamException;
+import lombok.Getter;
 import lombok.NonNull;
+import lombok.ToString;
 import org.apache.commons.lang3.StringUtils;
 
 import java.nio.charset.StandardCharsets;
@@ -30,6 +32,8 @@ import java.util.concurrent.TimeUnit;
 /**
  * Parameters for client connection.
  */
+@Getter
+@ToString
 public class ConnectParam {
     private final String host;
     private final int port;
@@ -44,8 +48,13 @@ public class ConnectParam {
     private final boolean secure;
     private final long idleTimeoutMs;
     private final String authorization;
+    private final String clientKeyPath;
+    private final String clientPemPath;
+    private final String caPemPath;
+    private final String serverPemPath;
+    private final String serverName;
 
-    private ConnectParam(@NonNull Builder builder) {
+    protected ConnectParam(@NonNull Builder builder) {
         this.host = builder.host;
         this.port = builder.port;
         this.token = builder.token;
@@ -59,50 +68,11 @@ public class ConnectParam {
         this.rpcDeadlineMs = builder.rpcDeadlineMs;
         this.secure = builder.secure;
         this.authorization = builder.authorization;
-    }
-
-    public String getHost() {
-        return host;
-    }
-
-    public int getPort() {
-        return port;
-    }
-
-    public long getConnectTimeoutMs() {
-        return connectTimeoutMs;
-    }
-
-    public long getKeepAliveTimeMs() {
-        return keepAliveTimeMs;
-    }
-
-    public long getKeepAliveTimeoutMs() {
-        return keepAliveTimeoutMs;
-    }
-
-    public boolean isKeepAliveWithoutCalls() {
-        return keepAliveWithoutCalls;
-    }
-    public long getIdleTimeoutMs() {
-        return idleTimeoutMs;
-    }
-
-    public long getRpcDeadlineMs() {
-        return rpcDeadlineMs;
-    }
-
-    public boolean isSecure() {
-        return secure;
-    }
-
-
-    public String getAuthorization() {
-        return authorization;
-    }
-
-    public String getDatabaseName() {
-        return databaseName;
+        this.clientKeyPath = builder.clientKeyPath;
+        this.clientPemPath = builder.clientPemPath;
+        this.caPemPath = builder.caPemPath;
+        this.serverPemPath = builder.serverPemPath;
+        this.serverName = builder.serverName;
     }
 
     public static Builder newBuilder() {
@@ -112,6 +82,7 @@ public class ConnectParam {
     /**
      * Builder for {@link ConnectParam}
      */
+    @Getter
     public static class Builder {
         private String host = "localhost";
         private int port = 19530;
@@ -124,11 +95,17 @@ public class ConnectParam {
         private boolean keepAliveWithoutCalls = false;
         private long rpcDeadlineMs = 0; // Disabling deadline
 
-        private boolean secure = false;
+        private String clientKeyPath;
+        private String clientPemPath;
+        private String caPemPath;
+        private String serverPemPath;
+        private String serverName;
+
+        protected boolean secure = false;
         private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
         private String authorization = Base64.getEncoder().encodeToString("root:milvus".getBytes(StandardCharsets.UTF_8));
 
-        private Builder() {
+        protected Builder() {
         }
 
         /**
@@ -167,7 +144,7 @@ public class ConnectParam {
         /**
          * Sets the uri
          *
-         * @param uri
+         * @param uri the uri of Milvus instance
          * @return <code>Builder</code>
          */
         public Builder withUri(String uri) {
@@ -178,7 +155,7 @@ public class ConnectParam {
         /**
          * Sets the token
          *
-         * @param token
+         * @param token serving as the key for identification and authentication purposes.
          * @return <code>Builder</code>
          */
         public Builder withToken(String token) {
@@ -239,6 +216,7 @@ public class ConnectParam {
          * @param enable true keep-alive
          * @return <code>Builder</code>
          */
+        @java.lang.Deprecated
         public Builder secure(boolean enable) {
             secure = enable;
             return this;
@@ -282,7 +260,7 @@ public class ConnectParam {
         }
 
         /**
-         * Sets secure the authorization for this connection
+         * Sets secure the authorization for this connection, set to True to enable TLS
          * @param secure boolean
          * @return <code>Builder</code>
          */
@@ -301,12 +279,69 @@ public class ConnectParam {
             return this;
         }
 
+        /**
+         * Set the client.key path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param clientKeyPath path of client.key
+         * @return <code>Builder</code>
+         */
+        public Builder withClientKeyPath(@NonNull String clientKeyPath) {
+            this.clientKeyPath = clientKeyPath;
+            return this;
+        }
+
+        /**
+         * Set the client.pem path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param clientPemPath path of client.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withClientPemPath(@NonNull String clientPemPath) {
+            this.clientPemPath = clientPemPath;
+            return this;
+        }
+
+        /**
+         * Set the ca.pem path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param caPemPath path of ca.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withCaPemPath(@NonNull String caPemPath) {
+            this.caPemPath = caPemPath;
+            return this;
+        }
+
+        /**
+         * Set the server.pem path for tls one-way authentication, only takes effect when "secure" is True.
+         * @param serverPemPath path of server.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withServerPemPath(@NonNull String serverPemPath) {
+            this.serverPemPath = serverPemPath;
+            return this;
+        }
+
+        /**
+         * Set target name override for SSL host name checking, only takes effect when "secure" is True.
+         * Note: this value is passed to grpc.ssl_target_name_override
+         * @param serverName override name for SSL host
+         * @return <code>Builder</code>
+         */
+        public Builder withServerName(@NonNull String serverName) {
+            this.serverName = serverName;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link ConnectParam} instance.
          *
          * @return {@link ConnectParam}
          */
         public ConnectParam build() throws ParamException {
+            verify();
+
+            return new ConnectParam(this);
+        }
+
+        protected void verify() throws ParamException {
             ParamUtils.CheckNullEmptyString(host, "Host name");
             if (StringUtils.isNotEmpty(uri)) {
                 io.milvus.utils.URLParser result = new io.milvus.utils.URLParser(uri);
@@ -321,6 +356,7 @@ public class ConnectParam {
                 if (!token.contains(":")) {
                     this.port = 443;
                 }
+                this.secure = true; //
             }
 
             if (port < 0 || port > 0xFFFF) {
@@ -343,20 +379,10 @@ public class ConnectParam {
                 throw new ParamException("Idle timeout must be positive!");
             }
 
-            return new ConnectParam(this);
+            if (StringUtils.isNotEmpty(serverPemPath) || StringUtils.isNotEmpty(caPemPath)
+                    || StringUtils.isNotEmpty(clientPemPath) || StringUtils.isNotEmpty(clientKeyPath)) {
+                secure = true;
+            }
         }
     }
-
-    /**
-     * Constructs a <code>String</code> by {@link ConnectParam} instance.
-     *
-     * @return <code>String</code>
-     */
-    @Override
-    public String toString() {
-        return "ConnectParam{" +
-                "host='" + host + '\'' +
-                ", port='" + port +
-                '}';
-    }
 }

+ 149 - 111
src/main/java/io/milvus/param/MultiConnectParam.java

@@ -2,11 +2,11 @@ package io.milvus.param;
 
 import com.google.common.collect.Lists;
 import io.milvus.exception.ParamException;
+import lombok.Getter;
 import lombok.NonNull;
+import lombok.ToString;
 import org.apache.commons.collections4.CollectionUtils;
 
-import java.nio.charset.StandardCharsets;
-import java.util.Base64;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 
@@ -16,63 +16,16 @@ import static io.milvus.common.constant.MilvusClientConstant.MilvusConsts.HOST_H
 /**
  * Parameters for client connection of multi server.
  */
-public class MultiConnectParam {
+@Getter
+@ToString
+public class MultiConnectParam extends ConnectParam {
     private final List<ServerAddress> hosts;
     private final QueryNodeSingleSearch queryNodeSingleSearch;
-    private final long connectTimeoutMs;
-    private final long keepAliveTimeMs;
-    private final long keepAliveTimeoutMs;
-    private final boolean keepAliveWithoutCalls;
-    private final boolean secure;
-    private final long idleTimeoutMs;
-    private final String authorization;
 
     private MultiConnectParam(@NonNull Builder builder) {
+        super(builder);
         this.hosts = builder.hosts;
         this.queryNodeSingleSearch = builder.queryNodeSingleSearch;
-        this.connectTimeoutMs = builder.connectTimeoutMs;
-        this.keepAliveTimeMs = builder.keepAliveTimeMs;
-        this.keepAliveTimeoutMs = builder.keepAliveTimeoutMs;
-        this.keepAliveWithoutCalls = builder.keepAliveWithoutCalls;
-        this.secure = builder.secure;
-        this.idleTimeoutMs = builder.idleTimeoutMs;
-        this.authorization = builder.authorization;
-    }
-
-    public List<ServerAddress> getHosts() {
-        return hosts;
-    }
-
-    public QueryNodeSingleSearch getQueryNodeSingleSearch() {
-        return queryNodeSingleSearch;
-    }
-
-    public long getConnectTimeoutMs() {
-        return connectTimeoutMs;
-    }
-
-    public long getKeepAliveTimeMs() {
-        return keepAliveTimeMs;
-    }
-
-    public long getKeepAliveTimeoutMs() {
-        return keepAliveTimeoutMs;
-    }
-
-    public boolean isKeepAliveWithoutCalls() {
-        return keepAliveWithoutCalls;
-    }
-
-    public boolean isSecure() {
-        return secure;
-    }
-
-    public long getIdleTimeoutMs() {
-        return idleTimeoutMs;
-    }
-
-    public String getAuthorization() {
-        return authorization;
     }
 
     public static Builder newBuilder() {
@@ -82,16 +35,9 @@ public class MultiConnectParam {
     /**
      * Builder for {@link MultiConnectParam}
      */
-    public static class Builder {
+    public static class Builder extends ConnectParam.Builder {
         private List<ServerAddress> hosts;
         private QueryNodeSingleSearch queryNodeSingleSearch;
-        private long connectTimeoutMs = 10000;
-        private long keepAliveTimeMs = Long.MAX_VALUE; // Disabling keep alive
-        private long keepAliveTimeoutMs = 20000;
-        private boolean keepAliveWithoutCalls = false;
-        private boolean secure = false;
-        private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS);
-        private String authorization = "";
 
         private Builder() {
         }
@@ -118,6 +64,61 @@ public class MultiConnectParam {
             return this;
         }
 
+        /**
+         * Sets the host name/address.
+         *
+         * @param host host name/address
+         * @return <code>Builder</code>
+         */
+        public Builder withHost(@NonNull String host) {
+            super.withHost(host);
+            return this;
+        }
+
+        /**
+         * Sets the connection port. Port value must be greater than zero and less than 65536.
+         *
+         * @param port port value
+         * @return <code>Builder</code>
+         */
+        public Builder withPort(int port)  {
+            super.withPort(port);
+            return this;
+        }
+
+        /**
+         * Sets the database name.
+         *
+         * @param databaseName databaseName
+         * @return <code>Builder</code>
+         */
+        public Builder withDatabaseName(@NonNull String databaseName) {
+            super.withDatabaseName(databaseName);
+            return this;
+        }
+
+        /**
+         * Sets the uri
+         *
+         * @param uri the uri of Milvus instance
+         * @return <code>Builder</code>
+         */
+        public Builder withUri(String uri) {
+            super.withUri(uri);
+            return this;
+        }
+
+        /**
+         * Sets the token
+         *
+         * @param token serving as the key for identification and authentication purposes.
+         * @return <code>Builder</code>
+         */
+        public Builder withToken(String token) {
+            super.withToken(token);
+            return this;
+        }
+
         /**
          * Sets the connection timeout value of client channel. The timeout value must be greater than zero.
          *
@@ -126,7 +127,7 @@ public class MultiConnectParam {
          * @return <code>Builder</code>
          */
         public Builder withConnectTimeout(long connectTimeout, @NonNull TimeUnit timeUnit) {
-            this.connectTimeoutMs = timeUnit.toMillis(connectTimeout);
+            super.withConnectTimeout(connectTimeout, timeUnit);
             return this;
         }
 
@@ -138,7 +139,7 @@ public class MultiConnectParam {
          * @return <code>Builder</code>
          */
         public Builder withKeepAliveTime(long keepAliveTime, @NonNull TimeUnit timeUnit) {
-            this.keepAliveTimeMs = timeUnit.toMillis(keepAliveTime);
+            super.withKeepAliveTime(keepAliveTime, timeUnit);
             return this;
         }
 
@@ -150,7 +151,7 @@ public class MultiConnectParam {
          * @return <code>Builder</code>
          */
         public Builder withKeepAliveTimeout(long keepAliveTimeout, @NonNull TimeUnit timeUnit) {
-            this.keepAliveTimeoutMs = timeUnit.toNanos(keepAliveTimeout);
+            super.withKeepAliveTimeout(keepAliveTimeout, timeUnit);
             return this;
         }
 
@@ -161,51 +162,115 @@ public class MultiConnectParam {
          * @return <code>Builder</code>
          */
         public Builder keepAliveWithoutCalls(boolean enable) {
-            keepAliveWithoutCalls = enable;
+            super.keepAliveWithoutCalls(enable);
             return this;
         }
 
         /**
-         * Enables the secure for client channel.
+         * Sets the idle timeout value of client channel. The timeout value must be larger than zero.
          *
-         * @param enable true keep-alive
+         * @param idleTimeout timeout value
+         * @param timeUnit timeout unit
          * @return <code>Builder</code>
          */
-        public Builder secure(boolean enable) {
-            secure = enable;
+        public Builder withIdleTimeout(long idleTimeout, @NonNull TimeUnit timeUnit) {
+            super.withIdleTimeout(idleTimeout, timeUnit);
             return this;
         }
 
         /**
-         * Sets secure the authorization for this connection
+         * Set a deadline for how long you are willing to wait for a reply from the server.
+         * With a deadline setting, the client will wait when encounter fast RPC fail caused by network fluctuations.
+         * The deadline value must be larger than or equal to zero. Default value is 0, deadline is disabled.
+         *
+         * @param deadline deadline value
+         * @param timeUnit deadline unit
+         * @return <code>Builder</code>
+         */
+        public Builder withRpcDeadline(long deadline, @NonNull TimeUnit timeUnit) {
+            super.withRpcDeadline(deadline, timeUnit);
+            return this;
+        }
+
+        /**
+         * Sets the username and password for this connection
+         * @param username current user
+         * @param password password
+         * @return <code>Builder</code>
+         */
+        public Builder withAuthorization(String username, String password) {
+            super.withAuthorization(username, password);
+            return this;
+        }
+
+        /**
+         * Sets secure the authorization for this connection, set to True to enable TLS
          * @param secure boolean
          * @return <code>Builder</code>
          */
         public Builder withSecure(boolean secure) {
-            this.secure = secure;
+            super.withSecure(secure);
             return this;
         }
 
         /**
-         * Sets the idle timeout value of client channel. The timeout value must be larger than zero.
-         *
-         * @param idleTimeout timeout value
-         * @param timeUnit timeout unit
+         * Sets the secure for this connection
+         * @param authorization the authorization info that has included the encoded username and password info
          * @return <code>Builder</code>
          */
-        public Builder withIdleTimeout(long idleTimeout, @NonNull TimeUnit timeUnit) {
-            this.idleTimeoutMs = timeUnit.toMillis(idleTimeout);
+        public Builder withAuthorization(@NonNull String authorization) {
+            super.withAuthorization(authorization);
             return this;
         }
 
         /**
-         * Sets the username and password for this connection
-         * @param username current user
-         * @param password password
+         * Set the client.key path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param clientKeyPath path of client.key
          * @return <code>Builder</code>
          */
-        public Builder withAuthorization(@NonNull String username, @NonNull String password) {
-            this.authorization = Base64.getEncoder().encodeToString(String.format("%s:%s", username, password).getBytes(StandardCharsets.UTF_8));
+        public Builder withClientKeyPath(@NonNull String clientKeyPath) {
+            super.withClientKeyPath(clientKeyPath);
+            return this;
+        }
+
+        /**
+         * Set the client.pem path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param clientPemPath path of client.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withClientPemPath(@NonNull String clientPemPath) {
+            super.withClientPemPath(clientPemPath);
+            return this;
+        }
+
+        /**
+         * Set the ca.pem path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param caPemPath path of ca.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withCaPemPath(@NonNull String caPemPath) {
+            super.withCaPemPath(caPemPath);
+            return this;
+        }
+
+        /**
+         * Set the server.pem path for tls two-way authentication, only takes effect when "secure" is True.
+         * @param serverPemPath path of server.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withServerPemPath(@NonNull String serverPemPath) {
+            super.withServerPemPath(serverPemPath);
+            return this;
+        }
+
+        /**
+         * Set target name override for SSL host name checking, only takes effect when "secure" is True.
+         * Note: this value is passed to grpc.ssl_target_name_override
+         * @param serverName path of server.pem
+         * @return <code>Builder</code>
+         */
+        public Builder withServerName(@NonNull String serverName) {
+            super.withServerName(serverName);
             return this;
         }
 
@@ -215,6 +280,8 @@ public class MultiConnectParam {
          * @return {@link MultiConnectParam}
          */
         public MultiConnectParam build() throws ParamException {
+            super.verify();
+
             if (CollectionUtils.isEmpty(hosts)) {
                 throw new ParamException("Server addresses is empty!");
             }
@@ -241,36 +308,7 @@ public class MultiConnectParam {
             }
             this.withHosts(hostAddress);
 
-            if (keepAliveTimeMs <= 0L) {
-                throw new ParamException("Keep alive time must be positive!");
-            }
-
-            if (connectTimeoutMs <= 0L) {
-                throw new ParamException("Connect timeout must be positive!");
-            }
-
-            if (keepAliveTimeoutMs <= 0L) {
-                throw new ParamException("Keep alive timeout must be positive!");
-            }
-
-            if (idleTimeoutMs <= 0L) {
-                throw new ParamException("Idle timeout must be positive!");
-            }
-
             return new MultiConnectParam(this);
         }
     }
-
-    /**
-     * Constructs a <code>String</code> by {@link ConnectParam} instance.
-     *
-     * @return <code>String</code>
-     */
-    @Override
-    public String toString() {
-        final StringBuffer sb = new StringBuffer("MultiConnectParam{");
-        sb.append("hosts=").append(hosts);
-        sb.append('}');
-        return sb.toString();
-    }
 }

+ 4 - 9
src/main/java/io/milvus/param/ParamUtils.java

@@ -3,7 +3,6 @@ package io.milvus.param;
 import com.alibaba.fastjson.JSONObject;
 import com.google.common.collect.Lists;
 import com.google.protobuf.ByteString;
-import io.grpc.StatusRuntimeException;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.JacksonUtils;
 import io.milvus.exception.IllegalResponseException;
@@ -15,19 +14,15 @@ import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.response.DescCollResponseWrapper;
 import lombok.Builder;
-import lombok.Data;
 import lombok.Getter;
 import lombok.NonNull;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang3.StringUtils;
-import org.apache.commons.lang3.tuple.Pair;
 
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.*;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 /**
@@ -256,6 +251,7 @@ public class ParamUtils {
         for (FieldType fieldType : wrapper.getFields()) {
             if (fieldType.isPartitionKey()) {
                 isPartitionKeyEnabled = true;
+                break;
             }
         }
         if (isPartitionKeyEnabled) {
@@ -544,7 +540,6 @@ public class ParamUtils {
         add(DataType.BinaryVector);
     }};
 
-    @SuppressWarnings("unchecked")
     private static FieldData genFieldData(String fieldName, DataType dataType, List<?> objects) {
         return genFieldData(fieldName, dataType, objects, Boolean.FALSE);
     }
@@ -720,8 +715,8 @@ public class ParamUtils {
     @Builder
     @Getter
     public static class InsertDataInfo {
-        private String fieldName;
-        private DataType dataType;
-        private LinkedList<Object> data;
+        private final String fieldName;
+        private final DataType dataType;
+        private final LinkedList<Object> data;
     }
 }

+ 3 - 1
src/test/java/io/milvus/client/MilvusMultiClientDockerTest.java

@@ -126,7 +126,9 @@ class MilvusMultiClientDockerTest {
     public static void setUp() {
         startDockerContainer();
 
-        MultiConnectParam connectParam = multiConnectParamBuilder().withAuthorization("root", "Milvus").build();
+        MultiConnectParam connectParam = multiConnectParamBuilder()
+                .withAuthorization("root", "Milvus")
+                .build();
         client = new MilvusMultiServiceClient(connectParam);
 //        TimeUnit.SECONDS.sleep(10);
         generator = new RandomStringGenerator.Builder().withinRange('a', 'z').build();

+ 1 - 1
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -68,6 +68,7 @@ class MilvusServiceClientTest {
         return new MilvusServiceClient(connectParam);
     }
 
+    @SuppressWarnings("unchecked")
     private <T, P> void invokeFunc(Method testFunc, MilvusServiceClient client, T param, int ret, boolean equalRet) {
         try {
             R<P> resp = (R<P>) testFunc.invoke(client, param);
@@ -83,7 +84,6 @@ class MilvusServiceClientTest {
         }
     }
 
-    @SuppressWarnings("unchecked")
     private <T, P> void testFuncByName(String funcName, T param) {
         // start mock server
         MockMilvusServer server = startServer();