MilvusClientExample.java 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. import com.google.common.util.concurrent.Futures;
  20. import com.google.common.util.concurrent.ListenableFuture;
  21. import io.milvus.client.*;
  22. import org.testcontainers.containers.GenericContainer;
  23. import java.util.ArrayList;
  24. import java.util.List;
  25. import java.util.Map;
  26. import java.util.SplittableRandom;
  27. import java.util.stream.Collectors;
  28. import java.util.stream.DoubleStream;
  29. import java.util.stream.LongStream;
  30. // This is a simple example demonstrating how to use Milvus Java SDK v0.9.0.
  31. // For detailed API documentation, please refer to
  32. // https://milvus-io.github.io/milvus-sdk-java/javadoc/io/milvus/client/package-summary.html
  33. // You can also find more information on https://milvus.io/docs/overview.md
  34. public class MilvusClientExample {
  35. // Helper function that generates random vectors
  36. static List<List<Float>> generateVectors(int vectorCount, int dimension) {
  37. SplittableRandom splitCollectionRandom = new SplittableRandom();
  38. List<List<Float>> vectors = new ArrayList<>(vectorCount);
  39. for (int i = 0; i < vectorCount; ++i) {
  40. splitCollectionRandom = splitCollectionRandom.split();
  41. DoubleStream doubleStream = splitCollectionRandom.doubles(dimension);
  42. List<Float> vector =
  43. doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
  44. vectors.add(vector);
  45. }
  46. return vectors;
  47. }
  48. // Helper function that normalizes a vector if you are using IP (Inner Product) as your metric
  49. // type
  50. static List<Float> normalizeVector(List<Float> vector) {
  51. float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
  52. final float norm = (float) Math.sqrt(squareSum);
  53. vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
  54. return vector;
  55. }
  56. public static void main(String[] args) throws InterruptedException {
  57. String dockerImage = System.getProperty("docker_image_name", "milvusdb/milvus:0.11.0-cpu");
  58. try (GenericContainer milvusContainer = new GenericContainer(dockerImage).withExposedPorts(19530)) {
  59. milvusContainer.start();
  60. ConnectParam connectParam = new ConnectParam.Builder()
  61. .withHost("localhost")
  62. .withPort(milvusContainer.getFirstMappedPort())
  63. .build();
  64. run(connectParam);
  65. }
  66. }
  67. public static void run(ConnectParam connectParam) {
  68. // Create Milvus client
  69. MilvusClient client = new MilvusGrpcClient(connectParam).withLogging();
  70. // Create a collection with the following collection mapping
  71. final String collectionName = "example"; // collection name
  72. final int dimension = 128; // dimension of each vector
  73. // we choose IP (Inner Product) as our metric type
  74. CollectionMapping collectionMapping = CollectionMapping
  75. .create(collectionName)
  76. .addField("int64", DataType.INT64)
  77. .addField("float", DataType.FLOAT)
  78. .addVectorField("float_vec", DataType.VECTOR_FLOAT, dimension)
  79. .setParamsInJson("{\"segment_row_limit\": 50000, \"auto_id\": true}");
  80. client.createCollection(collectionMapping);
  81. if (!client.hasCollection(collectionName)) {
  82. throw new AssertionError("Collection not found");
  83. }
  84. System.out.println(collectionMapping.toString());
  85. // Get collection info
  86. CollectionMapping collectionInfo = client.getCollectionInfo(collectionName);
  87. // Insert randomly generated field values to collection
  88. final int vectorCount = 100000;
  89. List<Long> longValues = LongStream.range(0, vectorCount).boxed().collect(Collectors.toList());
  90. List<Float> floatValues = LongStream.range(0, vectorCount).boxed().map(Long::floatValue).collect(Collectors.toList());
  91. List<List<Float>> vectors = generateVectors(vectorCount, dimension).stream()
  92. .map(MilvusClientExample::normalizeVector)
  93. .collect(Collectors.toList());
  94. InsertParam insertParam = InsertParam
  95. .create(collectionName)
  96. .addField("int64", DataType.INT64, longValues)
  97. .addField("float", DataType.FLOAT, floatValues)
  98. .addVectorField("float_vec", DataType.VECTOR_FLOAT, vectors);
  99. // Insert returns a list of entity ids that you will be using (if you did not supply them
  100. // yourself) to reference the entities you just inserted
  101. List<Long> vectorIds = client.insert(insertParam);
  102. // Flush data in collection
  103. client.flush(collectionName);
  104. // Get current entity count of collection
  105. long entityCount = client.countEntities(collectionName);
  106. // Create index for the collection
  107. // We choose IVF_SQ8 as our index type here. Refer to Milvus documentation for a
  108. // complete explanation of different index types and their relative parameters.
  109. Index index = Index
  110. .create(collectionName, "float_vec")
  111. .setIndexType(IndexType.IVF_SQ8)
  112. .setMetricType(MetricType.L2)
  113. .setParamsInJson(new JsonBuilder().param("nlist", 2048).build());
  114. client.createIndex(index);
  115. // Get collection info
  116. String collectionStats = client.getCollectionStats(collectionName);
  117. System.out.format("Collection Stats: %s\n", collectionStats);
  118. // Check whether a partition exists in collection
  119. // Obviously we do not have partition "tag" now
  120. if (client.hasPartition(collectionName, "tag")) {
  121. throw new AssertionError("Unexpected partition found!");
  122. }
  123. // Search entities using DSL statement.
  124. // Searching the first 5 entities we just inserted by including them in DSL.
  125. final int searchBatchSize = 5;
  126. List<List<Float>> vectorsToSearch = vectors.subList(0, searchBatchSize);
  127. final long topK = 10;
  128. // Based on the index you created, the available search parameters will be different. Refer to
  129. // the Milvus documentation for how to set the optimal parameters based on your needs.
  130. String dsl = String.format(
  131. "{\"bool\": {"
  132. + "\"must\": [{"
  133. + " \"range\": {"
  134. + " \"float\": {\"GT\": -10, \"LT\": 100}"
  135. + " }},{"
  136. + " \"vector\": {"
  137. + " \"float_vec\": {"
  138. + " \"topk\": %d, \"metric_type\": \"IP\", \"type\": \"float\", \"query\": "
  139. + "%s, \"params\": {\"nprobe\": 50}"
  140. + " }}}]}}",
  141. topK, vectorsToSearch.toString());
  142. SearchParam searchParam = SearchParam
  143. .create(collectionName)
  144. .setDsl(dsl)
  145. .setParamsInJson("{\"fields\": [\"int64\", \"float\"]}");
  146. SearchResult searchResult = client.search(searchParam);
  147. List<List<SearchResult.QueryResult>> queryResultsList = searchResult.getQueryResultsList();
  148. final double epsilon = 0.01;
  149. for (int i = 0; i < searchBatchSize; i++) {
  150. // Since we are searching for vector that is already present in the collection,
  151. // the first result vector should be itself and the distance (inner product) should be
  152. // very close to 1 (some precision is lost during the process)
  153. SearchResult.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  154. if (firstQueryResult.getEntityId() != vectorIds.get(i)
  155. || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
  156. throw new AssertionError("Wrong results!");
  157. }
  158. }
  159. // You can also get result ids and distances separately
  160. List<List<Long>> resultIds = searchResult.getResultIdsList();
  161. List<List<Float>> resultDistances = searchResult.getResultDistancesList();
  162. // You can send search request asynchronously, which returns a ListenableFuture object
  163. ListenableFuture<SearchResult> searchResponseFuture = client.searchAsync(searchParam);
  164. // Get search response immediately. Obviously you will want to do more complicated stuff with
  165. // ListenableFuture
  166. Futures.getUnchecked(searchResponseFuture);
  167. // Delete the first 5 entities you just searched
  168. client.deleteEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
  169. client.flush(collectionName);
  170. // After deleting them, we call getEntityByID and obviously all 5 entities should not be returned.
  171. Map<Long, Map<String, Object>> entities = client.getEntityByID(collectionName, vectorIds.subList(0, searchBatchSize));
  172. if (!entities.isEmpty()) {
  173. throw new AssertionError("Unexpected entity count!");
  174. }
  175. // Compact the collection, erase deleted data from disk and rebuild index in background (if
  176. // the data size after compaction is still larger than indexFileSize). Data was only
  177. // soft-deleted until you call compact.
  178. client.compact(CompactParam.create(collectionName).setThreshold(0.2));
  179. // Drop index for the collection
  180. client.dropIndex(collectionName, "float_vec");
  181. // Drop collection
  182. client.dropCollection(collectionName);
  183. }
  184. }