ClientUtils.java 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  /*
   * Licensed to the Apache Software Foundation (ASF) under one
   * or more contributor license agreements. See the NOTICE file
   * distributed with this work for additional information
   * regarding copyright ownership. The ASF licenses this file
   * to you under the Apache License, Version 2.0 (the
   * "License"); you may not use this file except in compliance
   * with the License. You may obtain a copy of the License at
   *
   * http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing,
   * software distributed under the License is distributed on an
   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   * KIND, either express or implied. See the License for the
   * specific language governing permissions and limitations
   * under the License.
   */
  19. package io.milvus.v2.utils;
  20. import io.grpc.ManagedChannel;
  21. import io.grpc.ManagedChannelBuilder;
  22. import io.grpc.Metadata;
  23. import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
  24. import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
  25. import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
  26. import io.grpc.stub.MetadataUtils;
  27. import io.milvus.grpc.*;
  28. import io.milvus.v2.client.ConnectConfig;
  29. import org.apache.commons.lang3.StringUtils;
  30. import org.slf4j.Logger;
  31. import org.slf4j.LoggerFactory;
  32. import java.io.File;
  33. import java.io.IOException;
  34. import java.nio.charset.StandardCharsets;
  35. import java.util.Base64;
  36. import java.util.concurrent.TimeUnit;
  37. public class ClientUtils {
  38. Logger logger = LoggerFactory.getLogger(ClientUtils.class);
  39. RpcUtils rpcUtils = new RpcUtils();
  40. public ManagedChannel getChannel(ConnectConfig connectConfig){
  41. ManagedChannel channel = null;
  42. Metadata metadata = new Metadata();
  43. if (connectConfig.getAuthorization() != null) {
  44. metadata.put(Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER), Base64.getEncoder().encodeToString(connectConfig.getAuthorization().getBytes(StandardCharsets.UTF_8)));
  45. }
  46. if (StringUtils.isNotEmpty(connectConfig.getDbName())) {
  47. metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectConfig.getDbName());
  48. }
  49. try {
  50. if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) {
  51. // one-way tls
  52. SslContext sslContext = GrpcSslContexts.forClient()
  53. .trustManager(new File(connectConfig.getServerPemPath()))
  54. .build();
  55. NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort())
  56. .overrideAuthority(connectConfig.getServerName())
  57. .sslContext(sslContext)
  58. .maxInboundMessageSize(Integer.MAX_VALUE)
  59. .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
  60. .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
  61. .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
  62. .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
  63. .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
  64. if(connectConfig.isSecure()){
  65. builder.useTransportSecurity();
  66. }
  67. channel = builder.build();
  68. } else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath())
  69. && StringUtils.isNotEmpty(connectConfig.getClientKeyPath())
  70. && StringUtils.isNotEmpty(connectConfig.getCaPemPath())) {
  71. // tow-way tls
  72. SslContext sslContext = GrpcSslContexts.forClient()
  73. .trustManager(new File(connectConfig.getCaPemPath()))
  74. .keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath()))
  75. .build();
  76. NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort())
  77. .sslContext(sslContext)
  78. .maxInboundMessageSize(Integer.MAX_VALUE)
  79. .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
  80. .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
  81. .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
  82. .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
  83. .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
  84. if (connectConfig.getSecure()) {
  85. builder.useTransportSecurity();
  86. }
  87. if (StringUtils.isNotEmpty(connectConfig.getServerName())) {
  88. builder.overrideAuthority(connectConfig.getServerName());
  89. }
  90. channel = builder.build();
  91. } else {
  92. // no tls
  93. ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort())
  94. .usePlaintext()
  95. .maxInboundMessageSize(Integer.MAX_VALUE)
  96. .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS)
  97. .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS)
  98. .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls())
  99. .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS)
  100. .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata));
  101. if(connectConfig.isSecure()){
  102. builder.useTransportSecurity();
  103. }
  104. channel = builder.build();
  105. }
  106. } catch (IOException e) {
  107. logger.error("Failed to open credentials file, error:{}\n", e.getMessage());
  108. }
  109. assert channel != null;
  110. return channel;
  111. }
  112. public void checkDatabaseExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName) {
  113. String title = String.format("Check database %s exist", dbName);
  114. ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.newBuilder().build();
  115. ListDatabasesResponse response = blockingStub.listDatabases(listDatabasesRequest);
  116. rpcUtils.handleResponse(title, response.getStatus());
  117. if (!response.getDbNamesList().contains(dbName)) {
  118. throw new IllegalArgumentException("Database " + dbName + " not exist");
  119. }
  120. }
  121. public String getServerVersion(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) {
  122. GetVersionResponse response = blockingStub.getVersion(GetVersionRequest.newBuilder().build());
  123. rpcUtils.handleResponse("Get server version", response.getStatus());
  124. return response.getVersion();
  125. }
  126. }