MilvusClientExample.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. import com.google.common.util.concurrent.ListenableFuture;
  20. import com.google.gson.JsonObject;
  21. import io.milvus.client.*;
  22. import java.util.ArrayList;
  23. import java.util.List;
  24. import java.util.SplittableRandom;
  25. import java.util.concurrent.ExecutionException;
  26. import java.util.stream.Collectors;
  27. import java.util.stream.DoubleStream;
  // This is a simple example demonstrating how to use the Milvus Java SDK.
  // For detailed API documentation, please refer to
  // https://milvus-io.github.io/milvus-sdk-java/javadoc/io/milvus/client/package-summary.html
  // You can also find more information at https://milvus.io/
  32. public class MilvusClientExample {
  33. // Helper function that generates random vectors
  34. static List<List<Float>> generateVectors(long vectorCount, long dimension) {
  35. SplittableRandom splitcollectionRandom = new SplittableRandom();
  36. List<List<Float>> vectors = new ArrayList<>();
  37. for (long i = 0; i < vectorCount; ++i) {
  38. splitcollectionRandom = splitcollectionRandom.split();
  39. DoubleStream doubleStream = splitcollectionRandom.doubles(dimension);
  40. List<Float> vector =
  41. doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
  42. vectors.add(vector);
  43. }
  44. return vectors;
  45. }
  46. // Helper function that normalizes a vector if you are using IP (Inner Product) as your metric
  47. // type
  48. static List<Float> normalizeVector(List<Float> vector) {
  49. float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
  50. final float norm = (float) Math.sqrt(squareSum);
  51. vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
  52. return vector;
  53. }
  54. public static void main(String[] args) throws InterruptedException, ConnectFailedException {
  55. // You may need to change the following to the host and port of your Milvus server
  56. String host = "localhost";
  57. int port = 19530;
  58. if (args.length >= 2) {
  59. host = args[0];
  60. port = Integer.parseInt(args[1]);
  61. }
  62. // Create Milvus client
  63. MilvusClient client = new MilvusGrpcClient();
  64. // Connect to Milvus server
  65. ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build();
  66. try {
  67. Response connectResponse = client.connect(connectParam);
  68. } catch (ConnectFailedException e) {
  69. System.out.println("Failed to connect to Milvus server: " + e.toString());
  70. throw e;
  71. }
  72. // Check whether we are connected
  73. boolean connected = client.isConnected();
  74. // Create a collection with the following collection mapping
  75. final String collectionName = "example"; // collection name
  76. final long dimension = 128; // dimension of each vector
  77. final long indexFileSize = 1024; // maximum size (in MB) of each index file
  78. final MetricType metricType = MetricType.IP; // we choose IP (Inner Product) as our metric type
  79. CollectionMapping collectionMapping =
  80. new CollectionMapping.Builder(collectionName, dimension)
  81. .withIndexFileSize(indexFileSize)
  82. .withMetricType(metricType)
  83. .build();
  84. Response createCollectionResponse = client.createCollection(collectionMapping);
  85. // Check whether the collection exists
  86. HasCollectionResponse hasCollectionResponse = client.hasCollection(collectionName);
  87. // Get collection info
  88. GetCollectionInfoResponse getCollectionInfoResponse =
  89. client.getCollectionInfo(collectionName);
  90. // Insert randomly generated vectors to collection
  91. final int vectorCount = 100000;
  92. List<List<Float>> vectors = generateVectors(vectorCount, dimension);
  93. vectors =
  94. vectors.stream().map(MilvusClientExample::normalizeVector).collect(Collectors.toList());
  95. InsertParam insertParam =
  96. new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
  97. InsertResponse insertResponse = client.insert(insertParam);
  98. // Insert returns a list of vector ids that you will be using (if you did not supply them
  99. // yourself) to reference the vectors you just inserted
  100. List<Long> vectorIds = insertResponse.getVectorIds();
  101. // Flush data in collection
  102. Response flushResponse = client.flush(collectionName);
  103. // Get current entity count of collection
  104. CountEntitiesResponse ountEntitiesResponse =
  105. client.countEntities(collectionName);
  106. // Create index for the collection
  107. // We choose IVF_SQ8 as our index type here. Refer to IndexType javadoc for a
  108. // complete explanation of different index types
  109. final IndexType indexType = IndexType.IVF_SQ8;
  110. // Each index type has its optional parameters you can set. Refer to the Milvus documentation
  111. // for how to set the optimal parameters based on your needs.
  112. JsonObject indexParamsJson = new JsonObject();
  113. indexParamsJson.addProperty("nlist", 16384);
  114. Index index =
  115. new Index.Builder(collectionName, indexType)
  116. .withParamsInJson(indexParamsJson.toString())
  117. .build();
  118. Response createIndexResponse = client.createIndex(index);
  119. // Get index info for your collection
  120. GetIndexInfoResponse getIndexInfoResponse = client.getIndexInfo(collectionName);
  121. System.out.format("Index Info: %s\n", getIndexInfoResponse.getIndex().toString());
  122. // Get collection info
  123. Response getCollectionStatsResponse = client.getCollectionStats(collectionName);
  124. if (getCollectionStatsResponse.ok()) {
  125. // Collection info is sent back with JSON type string
  126. String jsonString = getCollectionStatsResponse.getMessage();
  127. System.out.format("Collection Stats: %s\n", jsonString);
  128. }
  129. // Check whether a partition exists in collection
  130. // Obviously we do not have partition "tag" now
  131. HasPartitionResponse testHasPartition = client.hasPartition(collectionName, "tag");
  132. if (testHasPartition.ok() && testHasPartition.hasPartition()) {
  133. throw new AssertionError("Wrong results!");
  134. }
  135. // Search vectors
  136. // Searching the first 5 vectors of the vectors we just inserted
  137. final int searchBatchSize = 5;
  138. List<List<Float>> vectorsToSearch = vectors.subList(0, searchBatchSize);
  139. final long topK = 10;
  140. // Based on the index you created, the available search parameters will be different. Refer to
  141. // the Milvus documentation for how to set the optimal parameters based on your needs.
  142. JsonObject searchParamsJson = new JsonObject();
  143. searchParamsJson.addProperty("nprobe", 20);
  144. SearchParam searchParam =
  145. new SearchParam.Builder(collectionName)
  146. .withFloatVectors(vectorsToSearch)
  147. .withTopK(topK)
  148. .withParamsInJson(searchParamsJson.toString())
  149. .build();
  150. SearchResponse searchResponse = client.search(searchParam);
  151. if (searchResponse.ok()) {
  152. List<List<SearchResponse.QueryResult>> queryResultsList =
  153. searchResponse.getQueryResultsList();
  154. final double epsilon = 0.001;
  155. for (int i = 0; i < searchBatchSize; i++) {
  156. // Since we are searching for vector that is already present in the collection,
  157. // the first result vector should be itself and the distance (inner product) should be
  158. // very close to 1 (some precision is lost during the process)
  159. SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  160. if (firstQueryResult.getVectorId() != vectorIds.get(i)
  161. || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
  162. throw new AssertionError("Wrong results!");
  163. }
  164. }
  165. }
  166. // You can also get result ids and distances separately
  167. List<List<Long>> resultIds = searchResponse.getResultIdsList();
  168. List<List<Float>> resultDistances = searchResponse.getResultDistancesList();
  169. // You can send search request asynchronously, which returns a ListenableFuture object
  170. ListenableFuture<SearchResponse> searchResponseFuture = client.searchAsync(searchParam);
  171. try {
  172. // Get search response immediately. Obviously you will want to do more complicated stuff with
  173. // ListenableFuture
  174. searchResponseFuture.get();
  175. } catch (ExecutionException e) {
  176. e.printStackTrace();
  177. }
  178. // Delete the first 5 of vectors you just searched
  179. Response deleteByIdsResponse =
  180. client.deleteEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
  181. // Flush again, so deletions to data become visible
  182. flushResponse = client.flush(collectionName);
  183. // Try to get the corresponding vector of the first id you just deleted.
  184. GetEntityByIDResponse getEntityByIDResponse =
  185. client.getEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
  186. // Obviously you won't get anything
  187. if (!getEntityByIDResponse.getFloatVectors().get(0).isEmpty()) {
  188. throw new AssertionError("This can never happen!");
  189. }
  190. // Compact the collection, erasing deleted data from disk and rebuild index in background (if
  191. // the data size after compaction is still larger than indexFileSize). Data was only
  192. // soft-deleted until you call compact.
  193. Response compactResponse = client.compact(collectionName);
  194. // Drop index for the collection
  195. Response dropIndexResponse = client.dropIndex(collectionName);
  196. // Drop collection
  197. Response dropCollectionResponse = client.dropCollection(collectionName);
  198. // Disconnect from Milvus server
  199. try {
  200. Response disconnectResponse = client.disconnect();
  201. } catch (InterruptedException e) {
  202. System.out.println("Failed to disconnect: " + e.toString());
  203. throw e;
  204. }
  205. }
  206. }