MilvusClientExample.java 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. import io.milvus.client.*;
  20. import java.util.ArrayList;
  21. import java.util.List;
  22. import java.util.SplittableRandom;
  23. import java.util.concurrent.TimeUnit;
  24. import java.util.stream.Collectors;
  25. import java.util.stream.DoubleStream;
  26. public class MilvusClientExample {
  27. // Helper function that generates random vectors
  28. static List<List<Float>> generateVectors(long vectorCount, long dimension) {
  29. SplittableRandom splittableRandom = new SplittableRandom();
  30. List<List<Float>> vectors = new ArrayList<>();
  31. for (int i = 0; i < vectorCount; ++i) {
  32. splittableRandom = splittableRandom.split();
  33. DoubleStream doubleStream = splittableRandom.doubles(dimension);
  34. List<Float> vector =
  35. doubleStream.boxed().map(Double::floatValue).collect(Collectors.toList());
  36. vectors.add(vector);
  37. }
  38. return vectors;
  39. }
  40. // Helper function that normalizes a vector if you are using IP (Inner Product) as your metric
  41. // type
  42. static List<Float> normalizeVector(List<Float> vector) {
  43. float squareSum = vector.stream().map(x -> x * x).reduce((float) 0, Float::sum);
  44. final float norm = (float) Math.sqrt(squareSum);
  45. vector = vector.stream().map(x -> x / norm).collect(Collectors.toList());
  46. return vector;
  47. }
  48. public static void main(String[] args) throws InterruptedException, ConnectFailedException {
  49. // You may need to change the following to the host and port of your Milvus server
  50. final String host = "192.168.1.149";
  51. final String port = "19530";
  52. // Create Milvus client
  53. MilvusClient client = new MilvusGrpcClient();
  54. // Connect to Milvus server
  55. final long waitTime = 1000; // Wait 1000 ms for client to establish a connection
  56. ConnectParam connectParam =
  57. new ConnectParam.Builder().withHost(host).withPort(port).withWaitTime(waitTime).build();
  58. try {
  59. Response connectResponse = client.connect(connectParam);
  60. } catch (ConnectFailedException e) {
  61. System.out.println(e.toString());
  62. throw e;
  63. }
  64. // Check whether we are connected
  65. boolean connected = client.isConnected();
  66. System.out.println("Connected = " + connected);
  67. // Create a table with the following table schema
  68. final String tableName = "example"; // table name
  69. final long dimension = 128; // dimension of each vector
  70. final long indexFileSize = 1024; // maximum size (in MB) of each index file
  71. final MetricType metricType = MetricType.IP; // we choose IP (Inner Product) as our metric type
  72. TableSchema tableSchema =
  73. new TableSchema.Builder(tableName, dimension)
  74. .withIndexFileSize(indexFileSize)
  75. .withMetricType(metricType)
  76. .build();
  77. Response createTableResponse = client.createTable(tableSchema);
  78. System.out.println(createTableResponse);
  79. // Check whether the table exists
  80. HasTableResponse hasTableResponse = client.hasTable(tableName);
  81. System.out.println(hasTableResponse);
  82. // Describe the table
  83. DescribeTableResponse describeTableResponse = client.describeTable(tableName);
  84. System.out.println(describeTableResponse);
  85. // Insert randomly generated vectors to table
  86. final int vectorCount = 100000;
  87. List<List<Float>> vectors = generateVectors(vectorCount, dimension);
  88. vectors =
  89. vectors.stream().map(MilvusClientExample::normalizeVector).collect(Collectors.toList());
  90. InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
  91. InsertResponse insertResponse = client.insert(insertParam);
  92. System.out.println(insertResponse);
  93. // Insert returns a list of vector ids that you will be using (if you did not supply them
  94. // yourself) to reference the vectors you just inserted
  95. List<Long> vectorIds = insertResponse.getVectorIds();
  96. // The data we just inserted won't be serialized and written to meta until the next second
  97. // wait 1 second here
  98. TimeUnit.SECONDS.sleep(1);
  99. // Get current row count of table
  100. GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
  101. System.out.println(getTableRowCountResponse);
  102. // Create index for the table
  103. // We choose IVF_SQ8 as our index type here. Refer to IndexType javadoc for a
  104. // complete explanation of different index types
  105. final IndexType indexType = IndexType.IVF_SQ8;
  106. Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8).build();
  107. CreateIndexParam createIndexParam =
  108. new CreateIndexParam.Builder(tableName).withIndex(index).build();
  109. Response createIndexResponse = client.createIndex(createIndexParam);
  110. System.out.println(createIndexResponse);
  111. // Describe the index for your table
  112. DescribeIndexResponse describeIndexResponse = client.describeIndex(tableName);
  113. System.out.println(describeIndexResponse);
  114. // Search vectors
  115. // Searching the first 5 vectors of the vectors we just inserted
  116. final int searchBatchSize = 5;
  117. List<List<Float>> vectorsToSearch = vectors.subList(0, searchBatchSize);
  118. final long topK = 10;
  119. SearchParam searchParam =
  120. new SearchParam.Builder(tableName, vectorsToSearch).withTopK(topK).build();
  121. SearchResponse searchResponse = client.search(searchParam);
  122. System.out.println(searchResponse);
  123. if (searchResponse.getResponse().ok()) {
  124. List<List<SearchResponse.QueryResult>> queryResultsList =
  125. searchResponse.getQueryResultsList();
  126. final double epsilon = 0.001;
  127. for (int i = 0; i < searchBatchSize; i++) {
  128. // Since we are searching for vector that is already present in the table,
  129. // the first result vector should be itself and the distance (inner product) should be
  130. // very close to 1 (some precision is lost during the process)
  131. SearchResponse.QueryResult firstQueryResult = queryResultsList.get(i).get(0);
  132. if (firstQueryResult.getVectorId() != vectorIds.get(i)
  133. || Math.abs(1 - firstQueryResult.getDistance()) > epsilon) {
  134. throw new AssertionError("Wrong results!");
  135. }
  136. }
  137. }
  138. // Drop index for the table
  139. Response dropIndexResponse = client.dropIndex(tableName);
  140. System.out.println(dropIndexResponse);
  141. // Drop table
  142. Response dropTableResponse = client.dropTable(tableName);
  143. System.out.println(dropTableResponse);
  144. // Disconnect from Milvus server
  145. try {
  146. Response disconnectResponse = client.disconnect();
  147. } catch (InterruptedException e) {
  148. System.out.println("Failed to disconnect: " + e.toString());
  149. throw e;
  150. }
  151. return;
  152. }
  153. }