/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package io.milvus.v2;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.dataformat.csv.CsvMapper;
import com.fasterxml.jackson.dataformat.csv.CsvSchema;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import io.milvus.bulkwriter.BulkImport;
import io.milvus.bulkwriter.BulkWriter;
import io.milvus.bulkwriter.LocalBulkWriter;
import io.milvus.bulkwriter.LocalBulkWriterParam;
import io.milvus.bulkwriter.RemoteBulkWriter;
import io.milvus.bulkwriter.RemoteBulkWriterParam;
import io.milvus.bulkwriter.common.clientenum.BulkFileType;
import io.milvus.bulkwriter.common.clientenum.CloudStorage;
import io.milvus.bulkwriter.common.utils.GeneratorUtils;
import io.milvus.bulkwriter.common.utils.ImportUtils;
import io.milvus.bulkwriter.common.utils.ParquetReaderUtils;
import io.milvus.bulkwriter.connect.AzureConnectParam;
import io.milvus.bulkwriter.connect.S3ConnectParam;
import io.milvus.bulkwriter.connect.StorageConnectParam;
import io.milvus.bulkwriter.request.describe.CloudDescribeImportRequest;
import io.milvus.bulkwriter.request.describe.MilvusDescribeImportRequest;
import io.milvus.bulkwriter.request.import_.CloudImportRequest;
import io.milvus.bulkwriter.request.import_.MilvusImportRequest;
import io.milvus.bulkwriter.request.list.CloudListImportJobsRequest;
import io.milvus.bulkwriter.request.list.MilvusListImportJobsRequest;
import io.milvus.v1.CommonUtils;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.*;
import io.milvus.v2.service.index.request.CreateIndexReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.response.QueryResp;
import org.apache.avro.generic.GenericData;
import org.apache.http.util.Asserts;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.TimeUnit;
public class BulkWriterExample {
    // milvus
    public static final String HOST = "127.0.0.1";
    public static final Integer PORT = 19530;
    public static final String USER_NAME = "user.name";
    public static final String PASSWORD = "password";

    private static final Gson GSON_INSTANCE = new Gson();
    private static final List<Integer> QUERY_IDS = Lists.newArrayList(100, 5000);
    /**
     * If you need to transfer the files generated by bulkWriter to remote storage
     * (AWS S3, GCP GCS, Azure Blob, Aliyun OSS, Tencent Cloud TOS),
     * configure the following parameters accordingly; otherwise, you can ignore them.
     */
    public static class StorageConsts {
        public static final CloudStorage cloudStorage = CloudStorage.MINIO;

        /**
         * If using remote storage such as AWS S3, GCP GCS, Aliyun OSS, Tencent Cloud TOS or MinIO,
         * configure the following parameters.
         */
        public static final String STORAGE_ENDPOINT = cloudStorage.getEndpoint("http://127.0.0.1:9000");
        public static final String STORAGE_BUCKET = "a-bucket"; // default bucket name of MinIO/Milvus standalone
        public static final String STORAGE_ACCESS_KEY = "minioadmin"; // default ak of MinIO/Milvus standalone
        public static final String STORAGE_SECRET_KEY = "minioadmin"; // default sk of MinIO/Milvus standalone

        /**
         * If using remote storage, configure the region.
         * If using local storage such as a local MinIO, set this parameter to an empty string.
         */
        public static final String STORAGE_REGION = "";

        /**
         * If using remote storage such as Azure Blob,
         * configure the following parameters.
         */
        public static final String AZURE_CONTAINER_NAME = "azure.container.name";
        public static final String AZURE_ACCOUNT_NAME = "azure.account.name";
        public static final String AZURE_ACCOUNT_KEY = "azure.account.key";
    }
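    // Note: only one of the two credential sets above is actually used at runtime.
    // buildStorageConnectParam() below picks the Azure parameters when cloudStorage is
    // CloudStorage.AZURE, and the S3-style endpoint/bucket/keys for every other vendor.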
    /**
     * If you have used remoteBulkWriter to generate remote data and want to import it through the
     * Import interface on Zilliz Cloud, you do not need to configure the object-related parameters
     * below (OBJECT_URL, OBJECT_ACCESS_KEY, OBJECT_SECRET_KEY). Just call the callCloudImport method;
     * the internal logic is already encapsulated for you.
     * <p>
     * If you already have data stored in remote storage (not generated through remoteBulkWriter) and
     * want to invoke the Import interface on Zilliz Cloud to import it, configure the parameters below
     * and then follow the exampleCloudImport method.
     * <p>
     * If you do not need to import data through the Import interface on Zilliz Cloud, you can ignore this.
     */
    public static class CloudImportConsts {
        /**
         * The value of the URL is fixed.
         * For overseas regions, it is: https://api.cloud.zilliz.com
         * For regions in China, it is: https://api.cloud.zilliz.com.cn
         */
        public static final String CLOUD_ENDPOINT = "https://api.cloud.zilliz.com";
        public static final String API_KEY = "_api_key_for_cluster_org_";
        public static final String CLUSTER_ID = "_your_cloud_cluster_id_";
        public static final String COLLECTION_NAME = "_collection_name_on_the_cluster_id_";
        // If partition_name is not specified, use ""
        public static final String PARTITION_NAME = "_partition_name_on_the_collection_";

        /**
         * Provide the complete URL for the file or folder you want to import,
         * similar to https://bucket-name.s3.region-code.amazonaws.com/object-name.
         * For more details, refer to https://docs.zilliz.com/docs/import-data-on-web-ui.
         */
        public static final String OBJECT_URL = "_your_storage_object_url_";
        public static final String OBJECT_ACCESS_KEY = "_your_storage_access_key_";
        public static final String OBJECT_SECRET_KEY = "_your_storage_secret_key_";
    }
    private static final String SIMPLE_COLLECTION_NAME = "java_sdk_bulkwriter_simple_v2";
    private static final String ALL_TYPES_COLLECTION_NAME = "java_sdk_bulkwriter_all_v2";
    private static final Integer DIM = 512;
    private static final Integer ARRAY_CAPACITY = 10;
    private MilvusClientV2 milvusClient;

    public static void main(String[] args) throws Exception {
        BulkWriterExample exampleBulkWriter = new BulkWriterExample();
        exampleBulkWriter.createConnection();

        List<BulkFileType> fileTypes = Lists.newArrayList(
                BulkFileType.PARQUET
        );

        exampleSimpleCollection(exampleBulkWriter, fileTypes);
        exampleAllTypesCollectionRemote(exampleBulkWriter, fileTypes);

        // To call the cloud import API, you need a cloud service from Zilliz Cloud (https://zilliz.com/cloud)
        // exampleCloudImport();
    }
    private void createConnection() {
        System.out.println("\nCreate connection...");
        String url = String.format("http://%s:%s", HOST, PORT);
        milvusClient = new MilvusClientV2(ConnectConfig.builder()
                .uri(url)
                .username(USER_NAME)
                .password(PASSWORD)
                .build());
        System.out.println("\nConnected");
    }
    private static void exampleSimpleCollection(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
        CreateCollectionReq.CollectionSchema collectionSchema = exampleBulkWriter.buildSimpleSchema();
        exampleBulkWriter.createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);

        for (BulkFileType fileType : fileTypes) {
            localWriter(collectionSchema, fileType);
        }

        for (BulkFileType fileType : fileTypes) {
            remoteWriter(collectionSchema, fileType);
        }

        // parallel append
        parallelAppend(collectionSchema);
    }
    private static void exampleAllTypesCollectionRemote(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
        // 4 vector types + all scalar types + dynamic field enabled, use the bulkInsert interface
        for (BulkFileType fileType : fileTypes) {
            CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
            exampleBulkWriter.callBulkInsert(collectionSchema, batchFiles);
            exampleBulkWriter.retrieveImportData();
        }

//        // 4 vector types + all scalar types + dynamic field enabled, use the cloud import API.
//        // You need a cloud service from Zilliz Cloud (https://zilliz.com/cloud)
//        for (BulkFileType fileType : fileTypes) {
//            CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
//            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
//            exampleBulkWriter.createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
//            exampleBulkWriter.callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME, "");
//            exampleBulkWriter.retrieveImportData();
//        }
    }
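    /**
     * Demonstrates LocalBulkWriter: rows are buffered and written as data files under the local
     * path (/tmp/bulk_writer in this example). The boolean passed to commit() is understood to be
     * the "async" flag, so commit(false) waits for the flush to finish before getBatchFiles()
     * is read; verify the exact semantics against your SDK version.
     */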
    private static void localWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws Exception {
        System.out.printf("\n===================== local writer (%s) ====================%n", fileType.name());
        LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
                .withCollectionSchema(collectionSchema)
                .withLocalPath("/tmp/bulk_writer")
                .withFileType(fileType)
                .withChunkSize(128 * 1024 * 1024)
                .build();

        try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
            // read data from csv
            readCsvSampleData("data/train_embeddings.csv", localBulkWriter);

            // append rows
            for (int i = 0; i < 100000; i++) {
                JsonObject row = new JsonObject();
                row.addProperty("path", "path_" + i);
                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
                row.addProperty("label", "label_" + i);
                localBulkWriter.appendRow(row);
            }

            System.out.printf("%s rows appended%n", localBulkWriter.getTotalRowCount());
            System.out.printf("%s rows in buffer not flushed%n", localBulkWriter.getBufferRowCount());
            localBulkWriter.commit(false);
            List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
            System.out.printf("Local writer done! output local files: %s%n", batchFiles);
        } catch (Exception e) {
            System.out.println("Local writer catch exception: " + e);
            throw e;
        }
    }
    private static void remoteWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws Exception {
        System.out.printf("\n===================== remote writer (%s) ====================%n", fileType.name());
        try (RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, fileType)) {
            // read data from csv
            readCsvSampleData("data/train_embeddings.csv", remoteBulkWriter);

            // append rows
            for (int i = 0; i < 100000; i++) {
                JsonObject row = new JsonObject();
                row.addProperty("path", "path_" + i);
                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
                row.addProperty("label", "label_" + i);
                remoteBulkWriter.appendRow(row);
            }

            System.out.printf("%s rows appended%n", remoteBulkWriter.getTotalRowCount());
            System.out.printf("%s rows in buffer not flushed%n", remoteBulkWriter.getBufferRowCount());
            remoteBulkWriter.commit(false);
            List<List<String>> batchFiles = remoteBulkWriter.getBatchFiles();
            System.out.printf("Remote writer done! output remote files: %s%n", batchFiles);
        } catch (Exception e) {
            System.out.println("Remote writer catch exception: " + e);
            throw e;
        }
    }
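    /**
     * Demonstrates appending rows from multiple threads to a single LocalBulkWriter. The example
     * relies on appendRow() being safe to call concurrently (as this sample assumes); after all
     * threads join, the generated Parquet files are re-read to verify the total row count.
     */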
    private static void parallelAppend(CreateCollectionReq.CollectionSchema collectionSchema) throws Exception {
        System.out.print("\n===================== parallel append ====================");
        LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder()
                .withCollectionSchema(collectionSchema)
                .withLocalPath("/tmp/bulk_writer")
                .withFileType(BulkFileType.PARQUET)
                .withChunkSize(128 * 1024 * 1024) // 128MB
                .build();

        try (LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam)) {
            List<Thread> threads = new ArrayList<>();
            int threadCount = 10;
            int rowsPerThread = 1000;
            for (int i = 0; i < threadCount; ++i) {
                int current = i;
                Thread thread = new Thread(() -> appendRow(localBulkWriter, current * rowsPerThread, (current + 1) * rowsPerThread));
                threads.add(thread);
                thread.start();
                System.out.printf("Thread %s started%n", thread.getName());
            }

            for (Thread thread : threads) {
                thread.join();
                System.out.printf("Thread %s finished%n", thread.getName());
            }

            System.out.println(localBulkWriter.getTotalRowCount() + " rows appended");
            System.out.println(localBulkWriter.getBufferRowCount() + " rows in buffer not flushed");
            localBulkWriter.commit(false);
            System.out.printf("Append finished, %s rows%n", threadCount * rowsPerThread);

            long rowCount = 0;
            List<List<String>> batchFiles = localBulkWriter.getBatchFiles();
            for (List<String> batch : batchFiles) {
                for (String filePath : batch) {
                    rowCount += readParquet(filePath);
                }
            }
            Asserts.check(rowCount == threadCount * rowsPerThread, String.format("rowCount %s not equals expected %s", rowCount, threadCount * rowsPerThread));
            System.out.println("Data is correct");
        } catch (Exception e) {
            System.out.println("parallelAppend catch exception: " + e);
            throw e;
        }
    }
    private static long readParquet(String localFilePath) throws Exception {
        final long[] rowCount = {0};
        new ParquetReaderUtils() {
            @Override
            public void readRecord(GenericData.Record record) {
                rowCount[0]++;
                String pathValue = record.get("path").toString();
                String labelValue = record.get("label").toString();
                Asserts.check(pathValue.replace("path_", "").equals(labelValue.replace("label_", "")),
                        String.format("the suffix of %s not equals the suffix of %s", pathValue, labelValue));
            }
        }.readParquet(localFilePath);
        System.out.printf("The file %s contains %s rows. Verify the content...%n", localFilePath, rowCount[0]);
        return rowCount[0];
    }
    private static void appendRow(LocalBulkWriter writer, int begin, int end) {
        try {
            for (int i = begin; i < end; ++i) {
                JsonObject row = new JsonObject();
                row.addProperty("path", "path_" + i);
                row.add("vector", GSON_INSTANCE.toJsonTree(GeneratorUtils.genFloatVector(DIM)));
                row.addProperty("label", "label_" + i);
                writer.appendRow(row);
                if (i % 100 == 0) {
                    System.out.printf("%s inserted %s items%n", Thread.currentThread().getName(), i - begin);
                }
            }
        } catch (Exception e) {
            System.out.println("failed to append row: " + e);
        }
    }
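    /**
     * Builds rows covering every field type declared in buildAllTypesSchema() (scalars, four
     * vector types, array fields, plus a dynamic field), uploads them through RemoteBulkWriter,
     * and returns the remote batch file paths produced by commit(). Rows whose id is in QUERY_IDS
     * are printed so they can be compared with the query results in retrieveImportData().
     */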
    private List<List<String>> allTypesRemoteWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws Exception {
        System.out.printf("\n===================== all field types (%s) ====================%n", fileType.name());

        try (RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, fileType)) {
            System.out.println("Append rows");
            int batchCount = 10000;
            for (int i = 0; i < batchCount; ++i) {
                JsonObject rowObject = new JsonObject();

                // scalar field
                rowObject.addProperty("id", i);
                rowObject.addProperty("bool", i % 5 == 0);
                rowObject.addProperty("int8", i % 128);
                rowObject.addProperty("int16", i % 1000);
                rowObject.addProperty("int32", i % 100000);
                rowObject.addProperty("float", i / 3.0f);
                rowObject.addProperty("double", i / 7.0);
                rowObject.addProperty("varchar", "varchar_" + i);
                rowObject.addProperty("json", String.format("{\"dummy\": %s, \"ok\": \"name_%s\"}", i, i));

                // vector field
                rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloatVector(DIM)));
                rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateBinaryVector(DIM).array()));
                rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloat16Vector(DIM, false).array()));
                rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateSparseVector()));

                // array field
                rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorBoolValue(10)));
                rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt8Value(10)));
                rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt16Value(10)));
                rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt32Value(10)));
                rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorLongValue(10)));
                rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorVarcharValue(10, 10)));
                rowObject.add("array_float", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorFloatValue(10)));
                rowObject.add("array_double", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorDoubleValue(10)));

                // dynamic fields
                if (collectionSchema.isEnableDynamicField()) {
                    rowObject.addProperty("dynamic", "dynamic_" + i);
                }

                if (QUERY_IDS.contains(i)) {
                    System.out.println(rowObject);
                }

                remoteBulkWriter.appendRow(rowObject);
            }

            System.out.printf("%s rows appended%n", remoteBulkWriter.getTotalRowCount());
            System.out.printf("%s rows in buffer not flushed%n", remoteBulkWriter.getBufferRowCount());
            System.out.println("Generate data files...");
            remoteBulkWriter.commit(false);
            System.out.printf("Data files have been uploaded: %s%n", remoteBulkWriter.getBatchFiles());
            return remoteBulkWriter.getBatchFiles();
        } catch (Exception e) {
            System.out.println("allTypesRemoteWriter catch exception: " + e);
            throw e;
        }
    }
    private static RemoteBulkWriter buildRemoteBulkWriter(CreateCollectionReq.CollectionSchema collectionSchema, BulkFileType fileType) throws IOException {
        StorageConnectParam connectParam = buildStorageConnectParam();
        RemoteBulkWriterParam bulkWriterParam = RemoteBulkWriterParam.newBuilder()
                .withCollectionSchema(collectionSchema)
                .withRemotePath("bulk_data")
                .withFileType(fileType)
                .withChunkSize(512 * 1024 * 1024)
                .withConnectParam(connectParam)
                .build();
        return new RemoteBulkWriter(bulkWriterParam);
    }
    private static StorageConnectParam buildStorageConnectParam() {
        StorageConnectParam connectParam;
        if (StorageConsts.cloudStorage == CloudStorage.AZURE) {
            String connectionStr = "DefaultEndpointsProtocol=https;AccountName=" + StorageConsts.AZURE_ACCOUNT_NAME +
                    ";AccountKey=" + StorageConsts.AZURE_ACCOUNT_KEY + ";EndpointSuffix=core.windows.net";
            connectParam = AzureConnectParam.newBuilder()
                    .withConnStr(connectionStr)
                    .withContainerName(StorageConsts.AZURE_CONTAINER_NAME)
                    .build();
        } else {
            connectParam = S3ConnectParam.newBuilder()
                    .withEndpoint(StorageConsts.STORAGE_ENDPOINT)
                    .withCloudName(StorageConsts.cloudStorage.getCloudName())
                    .withBucketName(StorageConsts.STORAGE_BUCKET)
                    .withAccessKey(StorageConsts.STORAGE_ACCESS_KEY)
                    .withSecretKey(StorageConsts.STORAGE_SECRET_KEY)
                    .withRegion(StorageConsts.STORAGE_REGION)
                    .build();
        }
        return connectParam;
    }
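    // A minimal sketch of how StorageConsts might look for AWS S3 instead of the local MinIO
    // defaults (placeholder values only; assumes your SDK's CloudStorage enum exposes an AWS constant):
    //   cloudStorage       = CloudStorage.AWS
    //   STORAGE_ENDPOINT   = "s3.us-west-2.amazonaws.com"   // hypothetical endpoint
    //   STORAGE_BUCKET     = "your-bucket"
    //   STORAGE_ACCESS_KEY = "your-access-key"
    //   STORAGE_SECRET_KEY = "your-secret-key"
    //   STORAGE_REGION     = "us-west-2"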
    private static void readCsvSampleData(String filePath, BulkWriter writer) throws IOException, InterruptedException {
        ClassLoader classLoader = BulkWriterExample.class.getClassLoader();
        URL resourceUrl = classLoader.getResource(filePath);
        filePath = new File(resourceUrl.getFile()).getAbsolutePath();

        CsvMapper csvMapper = new CsvMapper();
        File csvFile = new File(filePath);
        CsvSchema csvSchema = CsvSchema.builder().setUseHeader(true).build();
        Iterator<CsvDataObject> iterator = csvMapper.readerFor(CsvDataObject.class).with(csvSchema).readValues(csvFile);
        while (iterator.hasNext()) {
            CsvDataObject dataObject = iterator.next();
            JsonObject row = new JsonObject();
            row.add("vector", GSON_INSTANCE.toJsonTree(dataObject.toFloatArray()));
            row.addProperty("label", dataObject.getLabel());
            row.addProperty("path", dataObject.getPath());
            writer.appendRow(row);
        }
    }
    private static class CsvDataObject {
        @JsonProperty
        private String vector;
        @JsonProperty
        private String path;
        @JsonProperty
        private String label;

        public String getVector() {
            return vector;
        }

        public String getPath() {
            return path;
        }

        public String getLabel() {
            return label;
        }

        public List<Float> toFloatArray() {
            return GSON_INSTANCE.fromJson(vector, new TypeToken<List<Float>>() {
            }.getType());
        }
    }
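    /**
     * Imports the generated remote files into a self-hosted Milvus instance: (re)creates the
     * target collection, submits a MilvusImportRequest for the batch files, then polls the job
     * state every 5 seconds until it reports Completed or Failed, and finally prints the row count.
     */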
    private void callBulkInsert(CreateCollectionReq.CollectionSchema collectionSchema, List<List<String>> batchFiles) throws InterruptedException {
        createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, true);

        String url = String.format("http://%s:%s", HOST, PORT);
        System.out.println("\n===================== import files to milvus ====================");
        MilvusImportRequest milvusImportRequest = MilvusImportRequest.builder()
                .collectionName(ALL_TYPES_COLLECTION_NAME)
                .files(batchFiles)
                .build();
        String bulkImportResult = BulkImport.bulkImport(url, milvusImportRequest);
        System.out.println(bulkImportResult);

        JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
        String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
        System.out.println("Create a bulkInsert task, job id: " + jobId);

        System.out.println("\n===================== listBulkInsertJobs() ====================");
        MilvusListImportJobsRequest listImportJobsRequest = MilvusListImportJobsRequest.builder().collectionName(ALL_TYPES_COLLECTION_NAME).build();
        String listImportJobsResult = BulkImport.listImportJobs(url, listImportJobsRequest);
        System.out.println(listImportJobsResult);
        while (true) {
            System.out.println("Wait 5 seconds to check bulkInsert job state...");
            TimeUnit.SECONDS.sleep(5);

            System.out.println("\n===================== getBulkInsertState() ====================");
            MilvusDescribeImportRequest request = MilvusDescribeImportRequest.builder()
                    .jobId(jobId)
                    .build();
            String getImportProgressResult = BulkImport.getImportProgress(url, request);
            System.out.println(getImportProgressResult);

            JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
            String state = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
            String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
            if ("Failed".equals(state)) {
                String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
                System.out.printf("The job %s failed, reason: %s%n", jobId, reason);
                break;
            } else if ("Completed".equals(state)) {
                System.out.printf("The job %s completed%n", jobId);
                break;
            } else {
                System.out.printf("The job %s is running, state:%s progress:%s%n", jobId, state, progress);
            }
        }

        System.out.println("Collection row number: " + getCollectionStatistics());
    }
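    /**
     * A minimal sketch (not part of the original example flow): poll an import job until it
     * finishes or a time budget runs out, instead of looping forever as callBulkInsert() does.
     * It reuses the same BulkImport.getImportProgress() call shown above; the method name and
     * the timeout handling are illustrative assumptions, not SDK behavior.
     */
    private static boolean waitForImportJob(String url, String jobId, long timeoutSeconds) throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeoutSeconds * 1000L;
        while (System.currentTimeMillis() < deadline) {
            MilvusDescribeImportRequest request = MilvusDescribeImportRequest.builder()
                    .jobId(jobId)
                    .build();
            JsonObject progress = convertJsonObject(BulkImport.getImportProgress(url, request));
            String state = progress.getAsJsonObject("data").get("state").getAsString();
            if ("Completed".equals(state)) {
                return true;
            } else if ("Failed".equals(state)) {
                System.out.println("Import failed: " + progress.getAsJsonObject("data").get("reason").getAsString());
                return false;
            }
            TimeUnit.SECONDS.sleep(5); // same polling interval as callBulkInsert()
        }
        return false; // timed out while the job was still running
    }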
    private void callCloudImport(List<List<String>> batchFiles, String collectionName, String partitionName) throws InterruptedException {
        String objectUrl = StorageConsts.cloudStorage == CloudStorage.AZURE
                ? StorageConsts.cloudStorage.getAzureObjectUrl(StorageConsts.AZURE_ACCOUNT_NAME, StorageConsts.AZURE_CONTAINER_NAME, ImportUtils.getCommonPrefix(batchFiles))
                : StorageConsts.cloudStorage.getS3ObjectUrl(StorageConsts.STORAGE_BUCKET, ImportUtils.getCommonPrefix(batchFiles), StorageConsts.STORAGE_REGION);
        String accessKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_NAME : StorageConsts.STORAGE_ACCESS_KEY;
        String secretKey = StorageConsts.cloudStorage == CloudStorage.AZURE ? StorageConsts.AZURE_ACCOUNT_KEY : StorageConsts.STORAGE_SECRET_KEY;

        System.out.println("\n===================== call cloudImport ====================");
        CloudImportRequest bulkImportRequest = CloudImportRequest.builder()
                .objectUrl(objectUrl).accessKey(accessKey).secretKey(secretKey)
                .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(collectionName).partitionName(partitionName)
                .apiKey(CloudImportConsts.API_KEY)
                .build();
        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, bulkImportRequest);
        JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
        String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
        System.out.println("Create a cloudImport job, job id: " + jobId);

        System.out.println("\n===================== call cloudListImportJobs ====================");
        CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
        System.out.println(listImportJobsResult);
        while (true) {
            System.out.println("Wait 5 seconds to check cloudImport job state...");
            TimeUnit.SECONDS.sleep(5);

            System.out.println("\n===================== call cloudGetProgress ====================");
            CloudDescribeImportRequest request = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
            String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, request);
            JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
            String importProgressState = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
            String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
            if ("Failed".equals(importProgressState)) {
                String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
                System.out.printf("The job %s failed, reason: %s%n", jobId, reason);
                break;
            } else if ("Completed".equals(importProgressState)) {
                System.out.printf("The job %s completed%n", jobId);
                break;
            } else {
                System.out.printf("The job %s is running, state:%s progress:%s%n", jobId, importProgressState, progress);
            }
        }

        System.out.println("Collection row number: " + getCollectionStatistics());
    }
    /**
     * @param collectionName   name of the collection to create
     * @param collectionSchema collection schema
     * @param dropIfExist      if the collection already exists, drop it first and then create it again
     */
    private void createCollection(String collectionName, CreateCollectionReq.CollectionSchema collectionSchema, boolean dropIfExist) {
        System.out.println("\n===================== create collection ====================");
        checkMilvusClientIfExist();
        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
                .collectionName(collectionName)
                .collectionSchema(collectionSchema)
                .consistencyLevel(ConsistencyLevel.BOUNDED)
                .build();

        Boolean has = milvusClient.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build());
        if (has) {
            if (dropIfExist) {
                milvusClient.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
                milvusClient.createCollection(requestCreate);
            }
        } else {
            milvusClient.createCollection(requestCreate);
        }

        System.out.printf("Collection %s created%n", collectionName);
    }
    private void retrieveImportData() {
        createIndex();

        System.out.printf("Load collection and query items %s%n", QUERY_IDS);
        loadCollection();

        String expr = String.format("id in %s", QUERY_IDS);
        System.out.println(expr);

        List<QueryResp.QueryResult> results = query(expr, Lists.newArrayList("*"));
        System.out.println("Query results:");
        for (QueryResp.QueryResult result : results) {
            Map<String, Object> entity = result.getEntity();
            JsonObject rowObject = new JsonObject();
            // scalar field
            rowObject.addProperty("id", (Long) entity.get("id"));
            rowObject.addProperty("bool", (Boolean) entity.get("bool"));
            rowObject.addProperty("int8", (Integer) entity.get("int8"));
            rowObject.addProperty("int16", (Integer) entity.get("int16"));
            rowObject.addProperty("int32", (Integer) entity.get("int32"));
            rowObject.addProperty("float", (Float) entity.get("float"));
            rowObject.addProperty("double", (Double) entity.get("double"));
            rowObject.addProperty("varchar", (String) entity.get("varchar"));
            rowObject.add("json", (JsonElement) entity.get("json"));

            // vector field
            rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(entity.get("float_vector")));
            rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer) entity.get("binary_vector")).array()));
            rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer) entity.get("float16_vector")).array()));
            rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(entity.get("sparse_vector")));

            // array field
            rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(entity.get("array_bool")));
            rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(entity.get("array_int8")));
            rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(entity.get("array_int16")));
            rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(entity.get("array_int32")));
            rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(entity.get("array_int64")));
            rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(entity.get("array_varchar")));
            rowObject.add("array_float", GSON_INSTANCE.toJsonTree(entity.get("array_float")));
            rowObject.add("array_double", GSON_INSTANCE.toJsonTree(entity.get("array_double")));

            // dynamic field
            rowObject.addProperty("dynamic", (String) entity.get("dynamic"));

            System.out.println(rowObject);
        }
    }
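    /**
     * Creates an index on each vector field. In Milvus, every vector field must be indexed
     * before the collection can be loaded, which is why this runs ahead of loadCollection().
     */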
    private void createIndex() {
        System.out.println("Create index...");
        checkMilvusClientIfExist();
        List<IndexParam> indexes = new ArrayList<>();
        indexes.add(IndexParam.builder()
                .fieldName("float_vector")
                .indexType(IndexParam.IndexType.FLAT)
                .metricType(IndexParam.MetricType.L2)
                .build());
        indexes.add(IndexParam.builder()
                .fieldName("binary_vector")
                .indexType(IndexParam.IndexType.BIN_FLAT)
                .metricType(IndexParam.MetricType.HAMMING)
                .build());
        indexes.add(IndexParam.builder()
                .fieldName("float16_vector")
                .indexType(IndexParam.IndexType.FLAT)
                .metricType(IndexParam.MetricType.IP)
                .build());
        indexes.add(IndexParam.builder()
                .fieldName("sparse_vector")
                .indexType(IndexParam.IndexType.SPARSE_WAND)
                .metricType(IndexParam.MetricType.IP)
                .build());

        milvusClient.createIndex(CreateIndexReq.builder()
                .collectionName(ALL_TYPES_COLLECTION_NAME)
                .indexParams(indexes)
                .build());
    }
    private void loadCollection() {
        System.out.println("Loading Collection...");
        checkMilvusClientIfExist();
        milvusClient.loadCollection(LoadCollectionReq.builder()
                .collectionName(ALL_TYPES_COLLECTION_NAME)
                .build());
    }
    private List<QueryResp.QueryResult> query(String expr, List<String> outputFields) {
        System.out.println("========== query() ==========");
        checkMilvusClientIfExist();
        QueryReq queryReq = QueryReq.builder()
                .collectionName(ALL_TYPES_COLLECTION_NAME)
                .filter(expr)
                .outputFields(outputFields)
                .build();
        QueryResp response = milvusClient.query(queryReq);
        return response.getQueryResults();
    }
    private Long getCollectionStatistics() {
        System.out.println("========== getCollectionStatistics() ==========");
        checkMilvusClientIfExist();
        // Get the row count. Use ConsistencyLevel.STRONG to sync data to the query node so that it is visible.
        QueryResp countR = milvusClient.query(QueryReq.builder()
                .collectionName(ALL_TYPES_COLLECTION_NAME)
                .filter("")
                .outputFields(Collections.singletonList("count(*)"))
                .consistencyLevel(ConsistencyLevel.STRONG)
                .build());
        return (long) countR.getQueryResults().get(0).getEntity().get("count(*)");
    }
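    /**
     * Imports files that already live in your own remote storage into a Zilliz Cloud cluster,
     * using the object URL and credentials from CloudImportConsts. Unlike callBulkInsert(), this
     * method only submits the job and prints the progress and job list once; it does not poll
     * until completion.
     */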
    private static void exampleCloudImport() {
        System.out.println("\n===================== import files to cloud vectordb ====================");
        CloudImportRequest request = CloudImportRequest.builder()
                .objectUrl(CloudImportConsts.OBJECT_URL).accessKey(CloudImportConsts.OBJECT_ACCESS_KEY).secretKey(CloudImportConsts.OBJECT_SECRET_KEY)
                .clusterId(CloudImportConsts.CLUSTER_ID).collectionName(CloudImportConsts.COLLECTION_NAME).partitionName(CloudImportConsts.PARTITION_NAME)
                .apiKey(CloudImportConsts.API_KEY)
                .build();
        String bulkImportResult = BulkImport.bulkImport(CloudImportConsts.CLOUD_ENDPOINT, request);
        System.out.println(bulkImportResult);

        System.out.println("\n===================== get import job progress ====================");
        JsonObject bulkImportObject = convertJsonObject(bulkImportResult);
        String jobId = bulkImportObject.getAsJsonObject("data").get("jobId").getAsString();
        CloudDescribeImportRequest getImportProgressRequest = CloudDescribeImportRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).jobId(jobId).apiKey(CloudImportConsts.API_KEY).build();
        String getImportProgressResult = BulkImport.getImportProgress(CloudImportConsts.CLOUD_ENDPOINT, getImportProgressRequest);
        System.out.println(getImportProgressResult);

        System.out.println("\n===================== list import jobs ====================");
        CloudListImportJobsRequest listImportJobsRequest = CloudListImportJobsRequest.builder().clusterId(CloudImportConsts.CLUSTER_ID).currentPage(1).pageSize(10).apiKey(CloudImportConsts.API_KEY).build();
        String listImportJobsResult = BulkImport.listImportJobs(CloudImportConsts.CLOUD_ENDPOINT, listImportJobsRequest);
        System.out.println(listImportJobsResult);
    }
    private CreateCollectionReq.CollectionSchema buildSimpleSchema() {
        CreateCollectionReq.CollectionSchema schemaV2 = CreateCollectionReq.CollectionSchema.builder()
                .build();
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("id")
                .dataType(io.milvus.v2.common.DataType.Int64)
                .isPrimaryKey(Boolean.TRUE)
                .autoID(true)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("path")
                .dataType(io.milvus.v2.common.DataType.VarChar)
                .maxLength(512)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("label")
                .dataType(io.milvus.v2.common.DataType.VarChar)
                .maxLength(512)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("vector")
                .dataType(io.milvus.v2.common.DataType.FloatVector)
                .dimension(DIM)
                .build());
        return schemaV2;
    }
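    /**
     * Builds the "all types" schema used by the import examples: an Int64 primary key with
     * autoID disabled (ids come from the generated files), every scalar type, a JSON field,
     * four vector fields (float, binary, float16, sparse), eight array fields, and the dynamic
     * field enabled.
     */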
    private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
        CreateCollectionReq.CollectionSchema schemaV2 = CreateCollectionReq.CollectionSchema.builder()
                .enableDynamicField(true)
                .build();
        // scalar field
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("id")
                .dataType(io.milvus.v2.common.DataType.Int64)
                .isPrimaryKey(Boolean.TRUE)
                .autoID(false)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("bool")
                .dataType(io.milvus.v2.common.DataType.Bool)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("int8")
                .dataType(io.milvus.v2.common.DataType.Int8)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("int16")
                .dataType(io.milvus.v2.common.DataType.Int16)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("int32")
                .dataType(io.milvus.v2.common.DataType.Int32)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("float")
                .dataType(io.milvus.v2.common.DataType.Float)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("double")
                .dataType(io.milvus.v2.common.DataType.Double)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("varchar")
                .dataType(io.milvus.v2.common.DataType.VarChar)
                .maxLength(512)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("json")
                .dataType(io.milvus.v2.common.DataType.JSON)
                .build());

        // vector fields
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("float_vector")
                .dataType(io.milvus.v2.common.DataType.FloatVector)
                .dimension(DIM)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("binary_vector")
                .dataType(io.milvus.v2.common.DataType.BinaryVector)
                .dimension(DIM)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("float16_vector")
                .dataType(io.milvus.v2.common.DataType.Float16Vector)
                .dimension(DIM)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("sparse_vector")
                .dataType(io.milvus.v2.common.DataType.SparseFloatVector)
                .build());

        // array fields
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_bool")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Bool)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_int8")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Int8)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_int16")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Int16)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_int32")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Int32)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_int64")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Int64)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_varchar")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.VarChar)
                .maxLength(512)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_float")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Float)
                .build());
        schemaV2.addField(AddFieldReq.builder()
                .fieldName("array_double")
                .dataType(io.milvus.v2.common.DataType.Array)
                .maxCapacity(ARRAY_CAPACITY)
                .elementType(io.milvus.v2.common.DataType.Double)
                .build());
        return schemaV2;
    }
    private void checkMilvusClientIfExist() {
        if (milvusClient == null) {
            String msg = "milvusClient is null. Please initialize it by calling createConnection() first before use.";
            throw new RuntimeException(msg);
        }
    }

    private static JsonObject convertJsonObject(String result) {
        return GSON_INSTANCE.fromJson(result, JsonObject.class);
    }
}