Selaa lähdekoodia

support connectType if use oss-bucket (#1590)

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>
xushuang.hu 3 viikkoa sitten
vanhempi
commit
b6e579d762

+ 4 - 4
examples/src/main/java/io/milvus/v1/BulkWriterExample.java

@@ -448,7 +448,7 @@ public class BulkWriterExample {
 
     private static StorageConnectParam buildStorageConnectParam() {
         StorageConnectParam connectParam;
-        if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
+        if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
             String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
                     ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
             connectParam = AzureConnectParam.newBuilder()
@@ -541,11 +541,11 @@ public class BulkWriterExample {
     }
 
     private void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
-        String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
+        String objectUrl = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())
                 ? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
                 : StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
-        String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
-        String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
+        String accessKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
+        String secretKey = CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName()) ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
 
         System.out.println("\n===================== call cloudImport ====================");
         List<String> objectUrls = Lists.newArrayList(objectUrl);

+ 2 - 0
examples/src/main/java/io/milvus/v2/StageFileManagerExample.java

@@ -21,6 +21,7 @@ package io.milvus.v2;
 import com.google.gson.Gson;
 import io.milvus.bulkwriter.StageFileManager;
 import io.milvus.bulkwriter.StageFileManagerParam;
+import io.milvus.bulkwriter.common.clientenum.ConnectType;
 import io.milvus.bulkwriter.model.UploadFilesResult;
 import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
 
@@ -35,6 +36,7 @@ public class StageFileManagerExample {
                 .withCloudEndpoint("https://api.cloud.zilliz.com")
                 .withApiKey("_api_key_for_cluster_org_")
                 .withStageName("_stage_name_for_project_")
+                .withConnectType(ConnectType.AUTO)
                 .build();
         stageFileManager = new StageFileManager(stageFileManagerParam);
     }

+ 1 - 1
examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java

@@ -392,7 +392,7 @@ public class BulkWriterRemoteExample {
 
     private static StorageConnectParam buildStorageConnectParam() {
         StorageConnectParam connectParam;
-        if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
+        if (CloudStorage.isAzCloud(StorageConsts.cloudStorage.getCloudName())) {
             String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
                     ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
             connectParam = AzureConnectParam.newBuilder()

+ 2 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageBulkWriter.java

@@ -21,6 +21,7 @@ package io.milvus.bulkwriter;
 
 import com.google.common.collect.Lists;
 import com.google.gson.JsonObject;
+import io.milvus.bulkwriter.common.clientenum.ConnectType;
 import io.milvus.bulkwriter.model.UploadFilesResult;
 import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
 import io.milvus.common.utils.ExceptionUtils;
@@ -63,7 +64,7 @@ public class StageBulkWriter extends LocalBulkWriter {
     private StageFileManager initStageFileManagerParams(StageBulkWriterParam bulkWriterParam) throws IOException {
         StageFileManagerParam stageFileManagerParam = StageFileManagerParam.newBuilder()
                 .withCloudEndpoint(bulkWriterParam.getCloudEndpoint()).withApiKey(bulkWriterParam.getApiKey())
-                .withStageName(bulkWriterParam.getStageName())
+                .withStageName(bulkWriterParam.getStageName()).withConnectType(ConnectType.AUTO)
                 .build();
         return new StageFileManager(stageFileManagerParam);
     }

+ 11 - 3
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManager.java

@@ -20,10 +20,12 @@
 package io.milvus.bulkwriter;
 
 import com.google.gson.Gson;
+import io.milvus.bulkwriter.common.clientenum.ConnectType;
 import io.milvus.bulkwriter.common.utils.FileUtils;
 import io.milvus.bulkwriter.model.UploadFilesResult;
 import io.milvus.bulkwriter.request.stage.ApplyStageRequest;
 import io.milvus.bulkwriter.request.stage.UploadFilesRequest;
+import io.milvus.bulkwriter.resolver.EndpointResolver;
 import io.milvus.bulkwriter.response.ApplyStageResponse;
 import io.milvus.bulkwriter.restful.DataStageUtils;
 import io.milvus.bulkwriter.storage.StorageClient;
@@ -54,6 +56,7 @@ public class StageFileManager {
     private final String cloudEndpoint;
     private final String apiKey;
     private final String stageName;
+    private final ConnectType connectType;
     private final ExecutorService executor;
 
     private StorageClient storageClient;
@@ -63,7 +66,8 @@ public class StageFileManager {
         this.cloudEndpoint = stageWriterParam.getCloudEndpoint();
         this.apiKey = stageWriterParam.getApiKey();
         this.stageName = stageWriterParam.getStageName();
-        this.executor = Executors.newFixedThreadPool(20);
+        this.connectType = stageWriterParam.getConnectType();
+        this.executor = Executors.newFixedThreadPool(10);
     }
 
     /**
@@ -138,7 +142,7 @@ public class StageFileManager {
     public void shutdownGracefully() {
         executor.shutdown();
         try {
-            if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
+            if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
                 logger.warn("Executor didn't terminate in time, forcing shutdown...");
                 executor.shutdownNow();
             }
@@ -168,9 +172,11 @@ public class StageFileManager {
         applyStageResponse = new Gson().fromJson(result, ApplyStageResponse.class);
         logger.info("stage info refreshed");
 
+        String endpoint = EndpointResolver.resolveEndpoint(applyStageResponse.getEndpoint(), applyStageResponse.getCloud(),
+                applyStageResponse.getRegion(), connectType);
         storageClient = MinioStorageClient.getStorageClient(
                 applyStageResponse.getCloud(),
-                applyStageResponse.getEndpoint(),
+                endpoint,
                 applyStageResponse.getCredentials().getTmpAK(),
                 applyStageResponse.getCredentials().getTmpSK(),
                 applyStageResponse.getCredentials().getSessionToken(),
@@ -235,6 +241,8 @@ public class StageFileManager {
         while (attempt < maxRetries) {
             try {
                 return callable.call();
+            } catch (RuntimeException e) {
+                throw e;
             } catch (Exception e) {
                 attempt++;
                 refreshStageAndClient(stagePath);

+ 16 - 0
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/StageFileManagerParam.java

@@ -19,6 +19,7 @@
 
 package io.milvus.bulkwriter;
 
+import io.milvus.bulkwriter.common.clientenum.ConnectType;
 import io.milvus.exception.ParamException;
 import io.milvus.param.ParamUtils;
 import lombok.Getter;
@@ -35,11 +36,13 @@ public class StageFileManagerParam {
     private final String cloudEndpoint;
     private final String apiKey;
     private final String stageName;
+    private final ConnectType connectType;
 
     private StageFileManagerParam(@NonNull Builder builder) {
         this.cloudEndpoint = builder.cloudEndpoint;
         this.apiKey = builder.apiKey;
         this.stageName = builder.stageName;
+        this.connectType = builder.connectType;
     }
 
     public static Builder newBuilder() {
@@ -56,6 +59,8 @@ public class StageFileManagerParam {
 
         private String stageName;
 
+        private ConnectType connectType = ConnectType.AUTO;
+
         private Builder() {
         }
 
@@ -79,6 +84,17 @@ public class StageFileManagerParam {
             return this;
         }
 
+        /**
+         * Sets how the storage endpoint is chosen; at present this mainly matters for Aliyun OSS buckets.
+         * The default, {@link ConnectType#AUTO}, uses the internal endpoint when it is reachable and
+         * falls back to the public endpoint otherwise. {@link ConnectType#INTERNAL} and
+         * {@link ConnectType#PUBLIC} force the use of the respective endpoint.
+         *
+         * @param connectType endpoint selection mode
+         * @return this builder, for call chaining
+         */
+        public Builder withConnectType(@NotNull ConnectType connectType) {
+            // NOTE(review): the constructor parameter uses @NonNull while this setter uses @NotNull, and the
+            // import for @NotNull is not visible in this hunk; confirm which annotation is intended.
+            this.connectType = connectType;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link StageFileManagerParam} instance.
          *

+ 38 - 2
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/CloudStorage.java

@@ -22,14 +22,25 @@ package io.milvus.bulkwriter.common.clientenum;
 import io.milvus.exception.ParamException;
 import lombok.Getter;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.util.Lists;
+
+import java.util.List;
 
 public enum CloudStorage {
     MINIO("minio","%s", "minioAddress"),
     AWS("aws","s3.amazonaws.com", null),
     GCP("gcp" ,"storage.googleapis.com", null),
+
+    AZ("az" ,"%s.blob.core.windows.net", "accountName"),
     AZURE("azure" ,"%s.blob.core.windows.net", "accountName"),
+
     ALI("ali","oss-%s.aliyuncs.com", "region"),
-    TC("tc","cos.%s.myqcloud.com", "region")
+    ALIYUN("aliyun","oss-%s.aliyuncs.com", "region"),
+    ALIBABA("alibaba","oss-%s.aliyuncs.com", "region"),
+    ALICLOU("alicloud","oss-%s.aliyuncs.com", "region"),
+
+    TC("tc","cos.%s.myqcloud.com", "region"),
+    TENCENT("tencent","cos.%s.myqcloud.com", "region")
     ;
 
     @Getter
@@ -45,6 +56,27 @@ public enum CloudStorage {
         this.replace = replace;
     }
 
+    /** Returns whether {@code cloudName} equals, ignoring case, the cloud name of any Alibaba Cloud alias. */
+    public static boolean isAliCloud(String cloudName) {
+        return Lists.newArrayList(ALI, ALIYUN, ALIBABA, ALICLOU).stream()
+                .map(CloudStorage::getCloudName)
+                .anyMatch(aliasName -> aliasName.equalsIgnoreCase(cloudName));
+    }
+
+    /** Returns whether {@code cloudName} equals, ignoring case, the cloud name of any Tencent Cloud alias. */
+    public static boolean isTcCloud(String cloudName) {
+        return Lists.newArrayList(TC, TENCENT).stream()
+                .map(CloudStorage::getCloudName)
+                .anyMatch(aliasName -> aliasName.equalsIgnoreCase(cloudName));
+    }
+
+    /** Returns whether {@code cloudName} equals, ignoring case, the cloud name of any Azure alias. */
+    public static boolean isAzCloud(String cloudName) {
+        return Lists.newArrayList(AZ, AZURE).stream()
+                .map(CloudStorage::getCloudName)
+                .anyMatch(aliasName -> aliasName.equalsIgnoreCase(cloudName));
+    }
+
     public static CloudStorage getCloudStorage(String cloudName) {
         for (CloudStorage cloudStorage : values()) {
             if (cloudStorage.getCloudName().equals(cloudName)) {
@@ -71,8 +103,12 @@ public enum CloudStorage {
             case GCP:
                 return String.format("https://storage.cloud.google.com/%s/%s", bucketName, commonPrefix);
             case TC:
+            case TENCENT:
                 return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, commonPrefix);
             case ALI:
+            case ALICLOU:
+            case ALIBABA:
+            case ALIYUN:
                 return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, commonPrefix);
             default:
                 throw new ParamException("no support others remote storage address");
@@ -80,7 +116,7 @@ public enum CloudStorage {
     }
 
     public String getAzureObjectUrl(String accountName, String containerName, String commonPrefix) {
-        if (this == CloudStorage.AZURE) {
+        if (CloudStorage.isAzCloud(this.getCloudName())) {
             return String.format("https://%s.blob.core.windows.net/%s/%s", accountName, containerName, commonPrefix);
         }
         throw new ParamException("no support others remote storage address");

+ 7 - 0
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/ConnectType.java

@@ -0,0 +1,7 @@
+package io.milvus.bulkwriter.common.clientenum;
+
+/**
+ * Selects which storage endpoint is used when connecting to a bucket.
+ * In this change it is consulted only when resolving Alibaba Cloud OSS endpoints.
+ */
+public enum ConnectType {
+    /** Use the internal endpoint if a probe of it succeeds, otherwise fall back to the public endpoint. */
+    AUTO,
+    /** Always use the internal endpoint. */
+    INTERNAL,
+    /** Always use the public endpoint. */
+    PUBLIC
+}

+ 0 - 22
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/StorageUtils.java

@@ -1,22 +0,0 @@
-package io.milvus.bulkwriter.common.utils;
-
-import io.milvus.bulkwriter.common.clientenum.CloudStorage;
-import io.milvus.exception.ParamException;
-
-public class StorageUtils {
-    public static String getObjectUrl(String cloudName, String bucketName, String objectPath, String region) {
-        CloudStorage cloudStorage = CloudStorage.getCloudStorage(cloudName);
-        switch (cloudStorage) {
-            case AWS:
-                return String.format("https://s3.%s.amazonaws.com/%s/%s", region, bucketName, objectPath);
-            case GCP:
-                return String.format("https://storage.cloud.google.com/%s/%s", bucketName, objectPath);
-            case TC:
-                return String.format("https://%s.cos.%s.myqcloud.com/%s", bucketName, region, objectPath);
-            case ALI:
-                return String.format("https://%s.oss-%s.aliyuncs.com/%s", bucketName, region, objectPath);
-            default:
-                throw new ParamException("no support others remote storage address");
-        }
-    }
-}

+ 71 - 0
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/resolver/EndpointResolver.java

@@ -0,0 +1,71 @@
+package io.milvus.bulkwriter.resolver;
+
+import io.milvus.bulkwriter.common.clientenum.CloudStorage;
+import io.milvus.bulkwriter.common.clientenum.ConnectType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Chooses the object-storage endpoint a client should connect to.
+ *
+ * <p>Only Alibaba Cloud OSS is resolved here: the endpoint is rebuilt from the region as either the
+ * internal host ({@code oss-<region>-internal.aliyuncs.com}) or the public host
+ * ({@code oss-<region>.aliyuncs.com}), depending on the {@link ConnectType}. For every other cloud the
+ * supplied endpoint is returned unchanged and no probe is performed.
+ */
+public class EndpointResolver {
+    private static final Logger logger = LoggerFactory.getLogger(EndpointResolver.class);
+
+    /** Connect and read timeout, in milliseconds, of a single reachability probe. */
+    private static final int PROBE_TIMEOUT_MS = (int) TimeUnit.SECONDS.toMillis(5);
+
+    /**
+     * Resolves the endpoint to use for the given cloud.
+     *
+     * @param defaultEndpoint endpoint returned unchanged when {@code cloud} is not an Alibaba Cloud alias
+     * @param cloud           cloud name, matched with {@link CloudStorage#isAliCloud(String)}
+     * @param region          region used to build the OSS host names
+     * @param connectType     endpoint selection mode; {@code null} is treated as {@link ConnectType#AUTO}
+     * @return the OSS endpoint chosen for Alibaba Cloud, otherwise {@code defaultEndpoint}
+     * @throws RuntimeException if probing the OSS endpoint that would be returned fails with an exception
+     */
+    public static String resolveEndpoint(String defaultEndpoint, String cloud, String region, ConnectType connectType) {
+        logger.info("Start resolving endpoint, cloud:{}, region:{}, connectType:{}", cloud, region, connectType);
+        String resolved = defaultEndpoint;
+        if (CloudStorage.isAliCloud(cloud)) {
+            // A null mode would make the switch below throw; fall back to AUTO like any unknown value.
+            resolved = resolveOssEndpoint(region, connectType == null ? ConnectType.AUTO : connectType);
+        }
+        // Non-OSS clouds are never probed, so the message must not claim that a check passed.
+        logger.info("Resolved endpoint: {}", resolved);
+        return resolved;
+    }
+
+    /**
+     * Picks the internal or public OSS endpoint for {@code region}.
+     *
+     * <p>INTERNAL and PUBLIC return the respective host after probing it. AUTO (and any other value)
+     * prefers the internal host and falls back to the public one when the internal probe does not succeed,
+     * i.e. it throws or answers with a status outside 200-399. A probe that throws on the endpoint being
+     * returned aborts resolution; a non-2xx/3xx status on that endpoint does not, because the boolean
+     * result of those probes is ignored.
+     */
+    private static String resolveOssEndpoint(String region, ConnectType connectType) {
+        String internalEndpoint = String.format("oss-%s-internal.aliyuncs.com", region);
+        String publicEndpoint = String.format("oss-%s.aliyuncs.com", region);
+
+        switch (connectType) {
+            case INTERNAL:
+                logger.info("Forced INTERNAL endpoint selected: {}", internalEndpoint);
+                checkEndpointReachable(internalEndpoint, true);
+                return internalEndpoint;
+            case PUBLIC:
+                logger.info("Forced PUBLIC endpoint selected: {}", publicEndpoint);
+                checkEndpointReachable(publicEndpoint, true);
+                return publicEndpoint;
+            case AUTO:
+            default:
+                if (checkEndpointReachable(internalEndpoint, false)) {
+                    logger.info("AUTO mode: internal endpoint reachable, using {}", internalEndpoint);
+                    return internalEndpoint;
+                }
+                logger.warn("AUTO mode: internal endpoint not reachable, fallback to public endpoint {}", publicEndpoint);
+                checkEndpointReachable(publicEndpoint, true);
+                return publicEndpoint;
+        }
+    }
+
+    /**
+     * Sends an HTTPS {@code HEAD} request to {@code endpoint} and reports whether it answered successfully.
+     *
+     * @param endpoint   host name to probe, without scheme
+     * @param printError when {@code true}, an exception raised by the probe is logged as an error and
+     *                   rethrown wrapped in a {@link RuntimeException}; when {@code false}, it is logged
+     *                   as a warning and {@code false} is returned
+     * @return {@code true} if the response code is in the range 200-399
+     * @throws RuntimeException if the probe throws and {@code printError} is {@code true}
+     */
+    private static boolean checkEndpointReachable(String endpoint, boolean printError) {
+        HttpURLConnection conn = null;
+        try {
+            URL url = new URL(String.format("https://%s", endpoint));
+            conn = (HttpURLConnection) url.openConnection();
+            conn.setConnectTimeout(PROBE_TIMEOUT_MS);
+            conn.setReadTimeout(PROBE_TIMEOUT_MS);
+            conn.setRequestMethod("HEAD");
+            int code = conn.getResponseCode();
+            logger.debug("Checked endpoint {}, response code={}", endpoint, code);
+            return code >= 200 && code < 400;
+        } catch (Exception e) {
+            if (printError) {
+                logger.error("Endpoint {} not reachable, throwing exception", endpoint, e);
+                // Keep the original exception as the cause so its type and stack trace are not lost.
+                throw new RuntimeException(e.getMessage(), e);
+            }
+            logger.warn("Endpoint {} not reachable, will fallback if needed", endpoint);
+            return false;
+        } finally {
+            // The probe connection is never reused; release it explicitly.
+            if (conn != null) {
+                conn.disconnect();
+            }
+        }
+    }
+}

+ 1 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/storage/client/MinioStorageClient.java

@@ -84,7 +84,7 @@ public class MinioStorageClient extends MinioAsyncClient implements StorageClien
         }
 
         MinioAsyncClient minioClient = minioClientBuilder.build();
-        if (CloudStorage.TC.getCloudName().equals(cloudName)) {
+        if (CloudStorage.isTcCloud(cloudName)) {
             minioClient.enableVirtualStyleEndpoint();
         }