- /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
- package io.milvus.v1;
- import com.fasterxml.jackson.annotation.JsonProperty;
- import com.fasterxml.jackson.dataformat.csv.CsvMapper;
- import com.fasterxml.jackson.dataformat.csv.CsvSchema;
- import com.google.common.collect.Lists;
- import com.google.gson.Gson;
- import com.google.gson.JsonElement;
- import com.google.gson.JsonObject;
- import com.google.gson.reflect.TypeToken;
- import io.milvus.bulkwriter.BulkWriter;
- import io.milvus.bulkwriter.CloudImport;
- import io.milvus.bulkwriter.LocalBulkWriter;
- import io.milvus.bulkwriter.LocalBulkWriterParam;
- import io.milvus.bulkwriter.RemoteBulkWriter;
- import io.milvus.bulkwriter.RemoteBulkWriterParam;
- import io.milvus.bulkwriter.common.clientenum.BulkFileType;
- import io.milvus.bulkwriter.common.clientenum.CloudStorage;
- import io.milvus.bulkwriter.common.utils.GeneratorUtils;
- import io.milvus.bulkwriter.common.utils.ImportUtils;
- import io.milvus.bulkwriter.common.utils.ParquetReaderUtils;
- import io.milvus.bulkwriter.connect.AzureConnectParam;
- import io.milvus.bulkwriter.connect.S3ConnectParam;
- import io.milvus.bulkwriter.connect.StorageConnectParam;
- import io.milvus.bulkwriter.request.BulkImportRequest;
- import io.milvus.bulkwriter.request.GetImportProgressRequest;
- import io.milvus.bulkwriter.request.ListImportJobsRequest;
- import io.milvus.client.MilvusClient;
- import io.milvus.client.MilvusServiceClient;
- import io.milvus.common.utils.ExceptionUtils;
- import io.milvus.grpc.DataType;
- import io.milvus.grpc.GetCollectionStatisticsResponse;
- import io.milvus.grpc.GetImportStateResponse;
- import io.milvus.grpc.ImportResponse;
- import io.milvus.grpc.ImportState;
- import io.milvus.grpc.KeyValuePair;
- import io.milvus.grpc.QueryResults;
- import io.milvus.param.ConnectParam;
- import io.milvus.param.IndexType;
- import io.milvus.param.MetricType;
- import io.milvus.param.R;
- import io.milvus.param.RpcStatus;
- import io.milvus.param.bulkinsert.BulkInsertParam;
- import io.milvus.param.bulkinsert.GetBulkInsertStateParam;
- import io.milvus.param.collection.CollectionSchemaParam;
- import io.milvus.param.collection.CreateCollectionParam;
- import io.milvus.param.collection.DropCollectionParam;
- import io.milvus.param.collection.FieldType;
- import io.milvus.param.collection.FlushParam;
- import io.milvus.param.collection.GetCollectionStatisticsParam;
- import io.milvus.param.collection.HasCollectionParam;
- import io.milvus.param.collection.LoadCollectionParam;
- import io.milvus.param.dml.QueryParam;
- import io.milvus.param.index.CreateIndexParam;
- import io.milvus.response.GetCollStatResponseWrapper;
- import io.milvus.response.QueryResultsWrapper;
- import org.apache.avro.generic.GenericData;
- import org.apache.commons.lang3.StringUtils;
- import org.apache.http.util.Asserts;
- import java.io.File;
- import java.io.IOException;
- import java.net.URL;
- import java.nio.ByteBuffer;
- import java.util.ArrayList;
- import java.util.Iterator;
- import java.util.List;
- import java.util.Optional;
- import java.util.concurrent.TimeUnit;
- public class BulkWriterExample {
- // Milvus connection settings
- public static final String HOST = "127.0.0.1";
- public static final Integer PORT = 19530;
- public static final String USER_NAME = "user.name";
- public static final String PASSWORD = "password";
- private static final Gson GSON_INSTANCE = new Gson();
- private static final List<Integer> QUERY_IDS = Lists.newArrayList(100, 5000);
- /**
- * If you need to transfer the files generated by bulkWriter to remote storage (AWS S3, GCP GCS, Azure Blob, Aliyun OSS, Tencent Cloud TOS),
- * configure the following parameters accordingly; otherwise, you can ignore them.
- */
- public static class StorageConsts {
- public static final CloudStorage cloudStorage = CloudStorage.MINIO;
- /**
- * If using remote storage such as AWS S3, GCP GCS, Aliyun OSS, Tencent Cloud TOS, or MinIO,
- * please configure the following parameters.
- */
- public static final String STORAGE_ENDPOINT = cloudStorage.getEndpoint("http://127.0.0.1:9000");
- public static final String STORAGE_BUCKET = "a-bucket"; // default bucket name of MinIO/Milvus standalone
- public static final String STORAGE_ACCESS_KEY = "minioadmin"; // default access key of MinIO/Milvus standalone
- public static final String STORAGE_SECRET_KEY = "minioadmin"; // default secret key of MinIO/Milvus standalone
- /**
- * If using remote storage, please configure this parameter.
- * If using local storage such as a local MinIO, set this parameter to an empty string.
- */
- public static final String STORAGE_REGION = "";
- /**
- * If using remote storage such as Azure Blob
- * please configure the following parameters.
- */
- public static final String AZURE_CONTAINER_NAME = "azure.container.name";
- public static final String AZURE_ACCOUNT_NAME = "azure.account.name";
- public static final String AZURE_ACCOUNT_KEY = "azure.account.key";
- }
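- /*
- * For illustration only: with AWS S3, the values above might look like
- * STORAGE_ENDPOINT = "https://s3.amazonaws.com", STORAGE_BUCKET = "my-bucket",
- * STORAGE_REGION = "us-west-2", plus your own access/secret keys. These are
- * placeholder values rather than tested settings; adjust them for your environment.
- */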
- /**
- * If you have used remoteBulkWriter to generate remote data and want to import it through the Import interface on Zilliz Cloud,
- * you don't need to configure the object-related parameters below (OBJECT_URL, OBJECT_ACCESS_KEY, OBJECT_SECRET_KEY); just call the callCloudImport method, which encapsulates that logic for you.
- * <p>
- * If you already have data in remote storage (not generated through remoteBulkWriter) and want to invoke the Import interface on Zilliz Cloud to import it,
- * configure the following parameters and then follow the exampleCloudImport method.
- * <p>
- * If you do not need to import data through the Import interface on Zilliz Cloud, you can ignore this.
- */
- public static class CloudImportConsts {
- /**
- * The value of the URL is fixed.
- * For overseas regions, it is: https://api.cloud.zilliz.com
- * For regions in China, it is: https://api.cloud.zilliz.com.cn
- */
- public static final String CLOUD_ENDPOINT = "https://api.cloud.zilliz.com";
- public static final String API_KEY = "_api_key_for_cluster_org_";
- public static final String CLUSTER_ID = "_your_cloud_cluster_id_";
- public static final String COLLECTION_NAME = "_collection_name_on_the_cluster_id_";
- // If partition_name is not specified, use ""
- public static final String PARTITION_NAME = "_partition_name_on_the_collection_";
- /**
- * Please provide the complete URL for the file or folder you want to import, similar to https://bucket-name.s3.region-code.amazonaws.com/object-name.
- * For more details, you can refer to https://docs.zilliz.com/docs/import-data-on-web-ui.
- */
- public static final String OBJECT_URL = "_your_storage_object_url_";
- public static final String OBJECT_ACCESS_KEY = "_your_storage_access_key_";
- public static final String OBJECT_SECRET_KEY = "_your_storage_secret_key_";
- }
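- // These constants are consumed by exampleCloudImport() below, which submits an import job
- // for an existing object URL, checks its progress, and lists the import jobs of the cluster.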
- private static final String SIMPLE_COLLECTION_NAME = "java_sdk_bulkwriter_simple_v1";
- private static final String ALL_TYPES_COLLECTION_NAME = "java_sdk_bulkwriter_all_v1";
- private static final Integer DIM = 512;
- private static final Integer ARRAY_CAPACITY = 10;
- private MilvusClient milvusClient;
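- /**
- * Entry point: connects to Milvus, then runs the simple-collection demo (local writer, remote
- * writer, parallel append) and the all-types demo (remote writer + bulkInsert + query back).
- * The cloud-import demo is left commented out because it requires a Zilliz Cloud cluster.
- */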
- public static void main(String[] args) throws Exception {
- BulkWriterExample exampleBulkWriter = new BulkWriterExample();
- exampleBulkWriter.createConnection();
- List<BulkFileType> fileTypes = Lists.newArrayList(
- BulkFileType.PARQUET
- );
- exampleSimpleCollection(exampleBulkWriter, fileTypes);
- exampleAllTypesCollectionRemote(exampleBulkWriter, fileTypes);
- // To call the cloud import API, you need to apply for a cloud service from Zilliz Cloud (https://zilliz.com/cloud)
- // exampleCloudImport();
- }
- private void createConnection() {
- System.out.println("\nCreate connection...");
- ConnectParam connectParam = ConnectParam.newBuilder()
- .withHost(HOST)
- .withPort(PORT)
- .withAuthorization(USER_NAME, PASSWORD)
- .build();
- milvusClient = new MilvusServiceClient(connectParam);
- System.out.println("\nConnected");
- }
- private static void exampleSimpleCollection(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
- CollectionSchemaParam collectionSchema = exampleBulkWriter.buildSimpleSchema();
- exampleBulkWriter.createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);
- for (BulkFileType fileType : fileTypes) {
- localWriter(collectionSchema, fileType);
- }
- for (BulkFileType fileType : fileTypes) {
- remoteWriter(collectionSchema, fileType);
- }
- // parallel append
- parallelAppend(collectionSchema);
- }
- private static void exampleAllTypesCollectionRemote(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
- // 4 vector types + all scalar types + dynamic field enabled, using the bulkInsert interface
- for (BulkFileType fileType : fileTypes) {
- CollectionSchemaParam collectionSchema = buildAllTypesSchema();
- List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
- exampleBulkWriter.callBulkInsert(collectionSchema, batchFiles);
- exampleBulkWriter.retrieveImportData();
- }
- // // 4 vector types + all scalar types + dynamic field enabled, using the cloud import API.
- // // You need to apply for a cloud service from Zilliz Cloud (https://zilliz.com/cloud)
- // for (BulkFileType fileType : fileTypes) {
- // CollectionSchemaParam collectionSchema = buildAllTypesSchema();
- // List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
- // exampleBulkWriter.createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
- // exampleBulkWriter.callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME, StringUtils.EMPTY);
- // exampleBulkWriter.retrieveImportData();
- // }
- }
- private static void localWriter(CollectionSchemaParam collectionSchema, BulkFileType fileType) throws Exception {
- System.out.printf("\n===================== local writer (%s) ====================%n", fileType.name());
- LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
- .withCollectionSchema(collectionSchema)
- .withLocalPath("/tmp/bulk_writer")
- .withFileType(fileType)
- .withChunkSize(128 * 1024 * 1024)
- .build();
- try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
- // read data from csv
- readCsvSampleData("data/train_embeddings.csv", localBulkWriter);
- // append rows
- for (int i = 0; i < 100000; i++) {
- JsonObject row = new JsonObject();
- row.addProperty("path", "path_" + i);
- row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
- row.addProperty("label", "label_" + i);
- localBulkWriter.appendRow(row);
- }
- System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount());
- System.out.printf("%s rows in buffer not flushed%n", localBulkWriter.getBufferRowCount());
- localBulkWriter.commit(false);
- List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
- System.out.printf("Local writer done! output local files: %s%n", batchFiles);
- } catch (Exception e) {
- System.out.println("Local writer catch exception: " + e);
- throw e;
- }
- }
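- // Note: the files produced by LocalBulkWriter stay under the local path ("/tmp/bulk_writer" here).
- // To feed them to bulkInsert they must first be uploaded to storage that Milvus can read
- // (e.g. the MinIO bucket backing a standalone deployment); remoteWriter() below performs that
- // upload automatically through RemoteBulkWriter.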
- private static void remoteWriter(CollectionSchemaParam collectionSchema, BulkFileType fileType) throws Exception {
- System.out.printf("\n===================== remote writer (%s) ====================%n", fileType.name());
- try (RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, fileType)) {
- // read data from csv
- readCsvSampleData("data/train_embeddings.csv", remoteBulkWriter);
- // append rows
- for (int i = 0; i < 100000; i++) {
- JsonObject row = new JsonObject();
- row.addProperty("path", "path_" + i);
- row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
- row.addProperty("label", "label_" + i);
- remoteBulkWriter.appendRow(row);
- }
- System.out.printf("%s rows appends%n", remoteBulkWriter.getTotalRowCount());
- System.out.printf("%s rows in buffer not flushed%n", remoteBulkWriter.getBufferRowCount());
- remoteBulkWriter.commit(false);
- List<List<String>> batchFiles = remoteBulkWriter.getBatchFiles();
- System.out.printf("Remote writer done! output remote files: %s%n", batchFiles);
- } catch (Exception e) {
- System.out.println("Remote writer catch exception: " + e);
- throw e;
- }
- }
- private static void parallelAppend(CollectionSchemaParam collectionSchema) throws Exception {
- System.out.print("\n===================== parallel append ====================");
- LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
- .withCollectionSchema(collectionSchema)
- .withLocalPath("/tmp/bulk_writer")
- .withFileType(BulkFileType.PARQUET)
- .withChunkSize(128 * 1024 * 1024) // 128MB
- .build();
- try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
- List<Thread> threads = new ArrayList<>();
- int threadCount = 10;
- int rowsPerThread = 1000;
- for (int i = 0; i < threadCount; ++i) {
- int current = i;
- Thread thread = new Thread(() -> appendRow(localBulkWriter, current * rowsPerThread, (current + 1) * rowsPerThread));
- threads.add(thread);
- thread.start();
- System.out.printf("Thread %s started%n", thread.getName());
- }
- for (Thread thread : threads) {
- thread.join();
- System.out.printf("Thread %s finished%n", thread.getName());
- }
- System.out.println(localBulkWriter.getTotalRowCount() + " rows appended");
- System.out.println(localBulkWriter.getBufferRowCount() + " rows in buffer not flushed");
- localBulkWriter.commit(false);
- System.out.printf("Append finished, %s rows%n", threadCount * rowsPerThread);
- int rowCount = 0;
- List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
- for (List<String> batch : batchFiles) {
- for (String filePath : batch) {
- rowCount += readParquet(filePath);
- }
- }
- Asserts.check(rowCount == threadCount * rowsPerThread, String.format("rowCount %s not equals expected %s", rowCount, threadCount * rowsPerThread));
- System.out.println("Data is correct");
- } catch (Exception e) {
- System.out.println("parallelAppend catch exception: " + e);
- throw e;
- }
- }
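- // parallelAppend() demonstrates that appendRow() can be called concurrently from several
- // threads against the same writer; the Parquet output is then re-read to verify that the
- // total row count matches threadCount * rowsPerThread.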
- private static long readParquet(String localFilePath) throws Exception {
- final long[] rowCount = {0};
- new ParquetReaderUtils() {
- @Override
- public void readRecord(GenericData.Record record) {
- rowCount[0]++;
- String pathValue = record.get("path").toString();
- String labelValue = record.get("label").toString();
- Asserts.check(pathValue.replace("path_", "").equals(labelValue.replace("label_", "")), String.format("the suffix of %s not equals the suffix of %s", pathValue, labelValue));
- }
- }.readParquet(localFilePath);
- System.out.printf("The file %s contains %s rows. Verify the content...%n", localFilePath, rowCount[0]);
- return rowCount[0];
- }
- private static void appendRow(LocalBulkWriter writer, int begin, int end) {
- try {
- for (int i = begin; i < end; ++i) {
- JsonObject row = new JsonObject();
- row.addProperty("path", "path_" + i);
- row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
- row.addProperty("label", "label_" + i);
- writer.appendRow(row);
- if (i % 100 == 0) {
- System.out.printf("%s inserted %s items%n", Thread.currentThread().getName(), i - begin);
- }
- }
- } catch (Exception e) {
- System.out.println("failed to append row!");
- }
- }
- private List<List<String>> allTypesRemoteWriter(CollectionSchemaParam collectionSchema, BulkFileType fileType) throws Exception {
- System.out.printf("\n===================== all field types (%s) ====================%n", fileType.name());
- try (RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, fileType)) {
- System.out.println("Append rows");
- int batchCount = 10000;
- for (int i = 0; i < batchCount; ++i) {
- JsonObject rowObject = new JsonObject();
- // scalar field
- rowObject.addProperty("id", i);
- rowObject.addProperty("bool", i % 5 == 0);
- rowObject.addProperty("int8", i % 128);
- rowObject.addProperty("int16", i % 1000);
- rowObject.addProperty("int32", i % 100000);
- rowObject.addProperty("float", i / 3);
- rowObject.addProperty("double", i / 7);
- rowObject.addProperty("varchar", "varchar_" + i);
- rowObject.addProperty("json", String.format("{\"dummy\": %s, \"ok\": \"name_%s\"}", i, i));
- // vector field
- rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloatVector(DIM)));
- rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateBinaryVector(DIM).array()));
- rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloat16Vector(DIM, false).array()));
- rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateSparseVector()));
- // array field
- rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorBoolValue(10)));
- rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt8Value(10)));
- rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt16Value(10)));
- rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt32Value(10)));
- rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorLongValue(10)));
- rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorVarcharValue(10, 10)));
- rowObject.add("array_float", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorFloatValue(10)));
- rowObject.add("array_double", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorDoubleValue(10)));
- // dynamic fields
- if (collectionSchema.isEnableDynamicField()) {
- rowObject.addProperty("dynamic", "dynamic_" + i);
- }
- if (QUERY_IDS.contains(i)) {
- System.out.println(rowObject);
- }
- remoteBulkWriter.appendRow(rowObject);
- }
- System.out.printf("%s rows appends%n", remoteBulkWriter.getTotalRowCount());
- System.out.printf("%s rows in buffer not flushed%n", remoteBulkWriter.getBufferRowCount());
- System.out.println("Generate data files...");
- remoteBulkWriter.commit(false);
- System.out.printf("Data files have been uploaded: %s%n", remoteBulkWriter.getBatchFiles());
- return remoteBulkWriter.getBatchFiles();
- } catch (Exception e) {
- System.out.println("allTypesRemoteWriter catch exception: " + e);
- throw e;
- }
- }
- private static RemoteBulkWriter buildRemoteBulkWriter(CollectionSchemaParam collectionSchema, BulkFileType fileType) throws IOException {
- StorageConnectParam connectParam = buildStorageConnectParam();
- RemoteBulkWriterParam bulkWriterParam = RemoteBulkWriterParam.newBuilder()
- .withCollectionSchema(collectionSchema)
- .withRemotePath("bulk_data")
- .withFileType(fileType)
- .withChunkSize(512 * 1024 * 1024)
- .withConnectParam(connectParam)
- .build();
- return new RemoteBulkWriter(bulkWriterParam);
- }
- private static StorageConnectParam buildStorageConnectParam() {
- StorageConnectParam connectParam;
- if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
- String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
- ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
- connectParam = AzureConnectParam.newBuilder()
- .withConnStr(connectionStr)
- .withContainerName(StorageConsts.AZURE_CONTAINER_NAME)
- .build();
- } else {
- connectParam = S3ConnectParam.newBuilder()
- .withEndpoint(StorageConsts.STORAGE_ENDPOINT)
- .withCloudName(StorageConsts.cloudStorage.getCloudName())
- .withBucketName(StorageConsts.STORAGE_BUCKET)
- .withAccessKey(StorageConsts.STORAGE_ACCESS_KEY)
- .withSecretKey(StorageConsts.STORAGE_SECRET_KEY)
- .withRegion(StorageConsts.STORAGE_REGION)
- .build();
- }
- return connectParam;
- }
- private static void readCsvSampleData(String filePath, BulkWriter writer) throws IOException, InterruptedException {
- ClassLoader classLoader = BulkWriterExample.class.getClassLoader();
- URL resourceUrl = classLoader.getResource(filePath);
- filePath = new File(resourceUrl.getFile()).getAbsolutePath();
- CsvMapper csvMapper = new CsvMapper();
- File csvFile = new File(filePath);
- CsvSchema csvSchema = CsvSchema.builder().setUseHeader(true).build();
- Iterator<CsvDataObject> iterator = csvMapper.readerFor(CsvDataObject.class).with(csvSchema).readValues(csvFile);
- while (iterator.hasNext()) {
- CsvDataObject dataObject = iterator.next();
- JsonObject row = new JsonObject();
- row.add("vector", GSON_INSTANCE.toJsonTree(dataObject.toFloatArray()));
- row.addProperty("label", dataObject.getLabel());
- row.addProperty("path", dataObject.getPath());
- writer.appendRow(row);
- }
- }
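- // The sample CSV is expected to have a header row with "vector", "path" and "label" columns,
- // where "vector" holds a JSON-encoded float array. A hypothetical row (for illustration only):
- // "[0.12, 0.34, ...]",path_0,label_0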
- private static class CsvDataObject {
- @JsonProperty
- private String vector;
- @JsonProperty
- private String path;
- @JsonProperty
- private String label;
- public String getVector() {
- return vector;
- }
- public String getPath() {
- return path;
- }
- public String getLabel() {
- return label;
- }
- public List<Float> toFloatArray() {
- return GSON_INSTANCE.fromJson(vector, new TypeToken<List<Float>>() {
- }.getType());
- }
- }
- private void callBulkInsert(CollectionSchemaParam collectionSchema, List<List<String>> batchFiles) throws InterruptedException {
- System.out.println("\n===================== call bulkInsert ====================");
- createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, true);
- List<Long> taskIds = new ArrayList<>();
- for (List<String> batch : batchFiles) {
- Long taskId = bulkInsert(batch);
- taskIds.add(taskId);
- System.out.println("Create a bulkInert task, task id: " + taskId);
- }
- while (!taskIds.isEmpty()) {
- Iterator<Long> iterator = taskIds.iterator();
- List<Long> tempTaskIds = new ArrayList<>();
- while (iterator.hasNext()) {
- Long taskId = iterator.next();
- System.out.println("Wait 5 second to check bulkInsert tasks state...");
- TimeUnit.SECONDS.sleep(5);
- GetImportStateResponse bulkInsertState = getBulkInsertState(taskId);
- if (bulkInsertState.getState() == ImportState.ImportFailed
- || bulkInsertState.getState() == ImportState.ImportFailedAndCleaned) {
- List<KeyValuePair> infosList = bulkInsertState.getInfosList();
- Optional<String> failedReasonOptional = infosList.stream().filter(e -> e.getKey().equals("failed_reason"))
- .map(KeyValuePair::getValue).findFirst();
- String failedReason = failedReasonOptional.orElse("");
- System.out.printf("The task %s failed, reason: %s%n", taskId, failedReason);
- } else if (bulkInsertState.getState() == ImportState.ImportCompleted) {
- System.out.printf("The task %s completed%n", taskId);
- } else {
- System.out.printf("The task %s is running, state:%s%n", taskId, bulkInsertState.getState());
- tempTaskIds.add(taskId);
- }
- }
- taskIds = tempTaskIds;
- }
- System.out.println("Collection row number: " + getCollectionStatistics());
- }
- private void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
- System.out.println("\n===================== call cloudImport ====================");
- String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
- ? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
- : StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
- String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
- String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;
- BulkImportRequest bulkImportRequest = BulkImportRequest.builder()
- .objectUrl(objectUrl).accessKey(accessKey).secretKey(secretKey)
- .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(collectionName).partitionName(partitionName)
- .build();
- String bulkImportResult = CloudImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, CloudImportConsts.API_KEY, bulkImportRequest);
- JsonObject bulkImportObject = convertDataMap(bulkImportResult);
- String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
- System.out.println("Create a cloudImport job, job id: " + jobId);
- while (true) {
- System.out.println("Wait 5 second to check bulkInsert job state...");
- TimeUnit.SECONDS.sleep(5);
- GetImportProgressRequest request = GetImportProgressRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).build();
- String getImportProgressResult = CloudImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, CloudImportConsts.API_KEY, request);
- JsonObject getImportProgressObject = convertDataMap(getImportProgressResult);
- String importProgressState = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
- String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
- String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
- if ("Completed".equals(importProgressState)) {
- System.out.printf("The job %s completed%n", jobId);
- break;
- } else if (StringUtils.isNotEmpty(reason)) {
- System.out.printf("The job %s failed or canceled, reason: %s%n", jobId, reason);
- break;
- } else {
- System.out.printf("The job %s is running, progress:%s%n", jobId, progress);
- }
- }
- System.out.println("Collection row number: " + getCollectionStatistics());
- }
- /**
- * @param collectionName name of the collection to create
- * @param collectionSchema collection schema
- * @param dropIfExist if the collection already exists, drop it first and then recreate it
- */
- private void createCollection(String collectionName, CollectionSchemaParam collectionSchema, boolean dropIfExist) {
- System.out.println("\n===================== create collection ====================");
- checkMilvusClientIfExist();
- CreateCollectionParam collectionParam = CreateCollectionParam.newBuilder()
- .withCollectionName(collectionName)
- .withSchema(collectionSchema)
- .build();
- R<Boolean> hasCollection = milvusClient.hasCollection(HasCollectionParam.newBuilder().withCollectionName(collectionName).build());
- if (hasCollection.getData()) {
- if (dropIfExist) {
- milvusClient.dropCollection(DropCollectionParam.newBuilder().withCollectionName(collectionName).build());
- milvusClient.createCollection(collectionParam);
- }
- } else {
- milvusClient.createCollection(collectionParam);
- }
- System.out.printf("Collection %s created%n", collectionName);
- }
- private void retrieveImportData() {
- createIndex();
- System.out.printf("Load collection and query items %s%n", QUERY_IDS);
- loadCollection();
- String expr = String.format("id in %s", QUERY_IDS);
- System.out.println(expr);
- List<QueryResultsWrapper.RowRecord> rowRecords = query(expr, Lists.newArrayList("*"));
- System.out.println("Query results:");
- for (QueryResultsWrapper.RowRecord record : rowRecords) {
- JsonObject rowObject = new JsonObject();
- // scalar field
- rowObject.addProperty("id", (Long)record.get("id"));
- rowObject.addProperty("bool", (Boolean) record.get("bool"));
- rowObject.addProperty("int8", (Integer) record.get("int8"));
- rowObject.addProperty("int16", (Integer) record.get("int16"));
- rowObject.addProperty("int32", (Integer) record.get("int32"));
- rowObject.addProperty("float", (Float) record.get("float"));
- rowObject.addProperty("double", (Double) record.get("double"));
- rowObject.addProperty("varchar", (String) record.get("varchar"));
- rowObject.add("json", (JsonElement) record.get("json"));
- // vector field
- rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(record.get("float_vector")));
- rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer)record.get("binary_vector")).array()));
- rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer)record.get("float16_vector")).array()));
- rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(record.get("sparse_vector")));
- // array field
- rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(record.get("array_bool")));
- rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(record.get("array_int8")));
- rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(record.get("array_int16")));
- rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(record.get("array_int32")));
- rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(record.get("array_int64")));
- rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(record.get("array_varchar")));
- rowObject.add("array_float", GSON_INSTANCE.toJsonTree(record.get("array_float")));
- rowObject.add("array_double", GSON_INSTANCE.toJsonTree(record.get("array_double")));
- // dynamic field
- rowObject.addProperty("dynamic", (String) record.get("dynamic"));
- System.out.println(rowObject);
- }
- }
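- // Milvus requires an index on every vector field before a collection can be loaded,
- // so createIndex() builds one index per vector field of the all-types collection
- // prior to loadCollection().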
- private void createIndex() {
- System.out.println("Create index...");
- checkMilvusClientIfExist();
- R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withFieldName("float_vector")
- .withIndexType(IndexType.FLAT)
- .withMetricType(MetricType.L2)
- .withSyncMode(Boolean.TRUE)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- response = milvusClient.createIndex(CreateIndexParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withFieldName("binary_vector")
- .withIndexType(IndexType.BIN_FLAT)
- .withMetricType(MetricType.HAMMING)
- .withSyncMode(Boolean.TRUE)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- response = milvusClient.createIndex(CreateIndexParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withFieldName("float16_vector")
- .withIndexType(IndexType.FLAT)
- .withMetricType(MetricType.IP)
- .withSyncMode(Boolean.TRUE)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- response = milvusClient.createIndex(CreateIndexParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withFieldName("sparse_vector")
- .withIndexType(IndexType.SPARSE_WAND)
- .withMetricType(MetricType.IP)
- .withSyncMode(Boolean.TRUE)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- }
- private R<RpcStatus> loadCollection() {
- System.out.println("Loading Collection...");
- checkMilvusClientIfExist();
- R<RpcStatus> response = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- return response;
- }
- private List<QueryResultsWrapper.RowRecord> query(String expr, List<String> outputFields) {
- System.out.println("========== query() ==========");
- checkMilvusClientIfExist();
- QueryParam test = QueryParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withExpr(expr)
- .withOutFields(outputFields)
- .build();
- R<QueryResults> response = milvusClient.query(test);
- ExceptionUtils.handleResponseStatus(response);
- QueryResultsWrapper wrapper = new QueryResultsWrapper(response.getData());
- return wrapper.getRowRecords();
- }
- private Long bulkInsert(List<String> batchFiles) {
- System.out.println("========== bulkInsert() ==========");
- checkMilvusClientIfExist();
- R<ImportResponse> response = milvusClient.bulkInsert(BulkInsertParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .withFiles(batchFiles)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- return response.getData().getTasksList().get(0);
- }
- private GetImportStateResponse getBulkInsertState(Long taskId) {
- System.out.println("========== getBulkInsertState() ==========");
- checkMilvusClientIfExist();
- R<GetImportStateResponse> bulkInsertState = milvusClient.getBulkInsertState(GetBulkInsertStateParam.newBuilder()
- .withTask(taskId)
- .build());
- return bulkInsertState.getData();
- }
- private Long getCollectionStatistics() {
- System.out.println("========== getCollectionStatistics() ==========");
- // Call flush() to flush the insert buffer to storage,
- // so that getCollectionStatistics() returns the correct row count.
- checkMilvusClientIfExist();
- milvusClient.flush(FlushParam.newBuilder().addCollectionName(ALL_TYPES_COLLECTION_NAME).build());
- R<GetCollectionStatisticsResponse> response = milvusClient.getCollectionStatistics(
- GetCollectionStatisticsParam.newBuilder()
- .withCollectionName(ALL_TYPES_COLLECTION_NAME)
- .build());
- ExceptionUtils.handleResponseStatus(response);
- GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper(response.getData());
- return wrapper.getRowCount();
- }
- private static void exampleCloudImport() {
- System.out.println("\n===================== import files to cloud vectordb ====================");
- BulkImportRequest request = BulkImportRequest.builder()
- .objectUrl(CloudImportConsts.OBJECT_URL).accessKey(CloudImportConsts.OBJECT_ACCESS_KEY).secretKey(CloudImportConsts.OBJECT_SECRET_KEY)
- .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(CloudImportConsts.COLLECTION_NAME).partitionName(CloudImportConsts.PARTITION_NAME)
- .build();
- String bulkImportResult = CloudImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, CloudImportConsts.API_KEY, request);
- System.out.println(bulkImportResult);
- System.out.println("\n===================== get import job progress ====================");
- JsonObject bulkImportObject = convertDataMap(bulkImportResult);
- String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
- GetImportProgressRequest getImportProgressRequest = GetImportProgressRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).build();
- String getImportProgressResult = CloudImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, CloudImportConsts.API_KEY, getImportProgressRequest);
- System.out.println(getImportProgressResult);
- System.out.println("\n===================== list import jobs ====================");
- ListImportJobsRequest listImportJobsRequest = ListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).build();
- String listImportJobsResult = CloudImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, CloudImportConsts.API_KEY, listImportJobsRequest);
- System.out.println(listImportJobsResult);
- }
- private CollectionSchemaParam buildSimpleSchema() {
- FieldType fieldType1 = FieldType.newBuilder()
- .withName("id")
- .withDataType(DataType.Int64)
- .withPrimaryKey(true)
- .withAutoID(true)
- .build();
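- // With autoID enabled, Milvus generates the primary key, so the rows appended by
- // localWriter()/remoteWriter() above only provide "vector", "path" and "label".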
- // vector field
- FieldType fieldType2 = FieldType.newBuilder()
- .withName("vector")
- .withDataType(DataType.FloatVector)
- .withDimension(DIM)
- .build();
- // scalar field
- FieldType fieldType3 = FieldType.newBuilder()
- .withName("path")
- .withDataType(DataType.VarChar)
- .withMaxLength(512)
- .build();
- FieldType fieldType4 = FieldType.newBuilder()
- .withName("label")
- .withDataType(DataType.VarChar)
- .withMaxLength(512)
- .build();
- return CollectionSchemaParam.newBuilder()
- .addFieldType(fieldType1)
- .addFieldType(fieldType2)
- .addFieldType(fieldType3)
- .addFieldType(fieldType4)
- .build();
- }
- private static CollectionSchemaParam buildAllTypesSchema() {
- List<FieldType> fieldTypes = new ArrayList<>();
- // scalar field
- fieldTypes.add(FieldType.newBuilder()
- .withName("id")
- .withDataType(DataType.Int64)
- .withPrimaryKey(true)
- .withAutoID(false)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("bool")
- .withDataType(DataType.Bool)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("int8")
- .withDataType(DataType.Int8)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("int16")
- .withDataType(DataType.Int16)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("int32")
- .withDataType(DataType.Int32)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("float")
- .withDataType(DataType.Float)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("double")
- .withDataType(DataType.Double)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("varchar")
- .withDataType(DataType.VarChar)
- .withMaxLength(512)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("json")
- .withDataType(DataType.JSON)
- .build());
- // vector fields
- fieldTypes.add(FieldType.newBuilder()
- .withName("float_vector")
- .withDataType(DataType.FloatVector)
- .withDimension(DIM)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("binary_vector")
- .withDataType(DataType.BinaryVector)
- .withDimension(DIM)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("float16_vector")
- .withDataType(DataType.Float16Vector)
- .withDimension(DIM)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("sparse_vector")
- .withDataType(DataType.SparseFloatVector)
- .build());
- // array fields
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_bool")
- .withDataType(DataType.Array)
- .withElementType(DataType.Bool)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_int8")
- .withDataType(DataType.Array)
- .withElementType(DataType.Int8)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_int16")
- .withDataType(DataType.Array)
- .withElementType(DataType.Int16)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_int32")
- .withDataType(DataType.Array)
- .withElementType(DataType.Int32)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_int64")
- .withDataType(DataType.Array)
- .withElementType(DataType.Int64)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_varchar")
- .withDataType(DataType.Array)
- .withElementType(DataType.VarChar)
- .withMaxLength(512)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_float")
- .withDataType(DataType.Array)
- .withElementType(DataType.Float)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- fieldTypes.add(FieldType.newBuilder()
- .withName("array_double")
- .withDataType(DataType.Array)
- .withElementType(DataType.Double)
- .withMaxCapacity(ARRAY_CAPACITY)
- .build());
- CollectionSchemaParam.Builder schemaBuilder = CollectionSchemaParam.newBuilder()
- .withEnableDynamicField(true)
- .withFieldTypes(fieldTypes);
- return schemaBuilder.build();
- }
- private void checkMilvusClientIfExist() {
- if (milvusClient == null) {
- String msg = "milvusClient is null. Please initialize it by calling createConnection() first before use.";
- throw new RuntimeException(msg);
- }
- }
- private static JsonObject convertDataMap(String result) {
- return GSON_INSTANCE.fromJson(result, JsonObject.class);
- }
- }