BinaryVectorExample.java 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
  19. package io.milvus.v2;
  20. import com.google.gson.Gson;
  21. import com.google.gson.JsonObject;
  22. import io.milvus.v1.CommonUtils;
  23. import io.milvus.v2.client.ConnectConfig;
  24. import io.milvus.v2.client.MilvusClientV2;
  25. import io.milvus.v2.common.ConsistencyLevel;
  26. import io.milvus.v2.common.DataType;
  27. import io.milvus.v2.common.IndexParam;
  28. import io.milvus.v2.service.collection.request.AddFieldReq;
  29. import io.milvus.v2.service.collection.request.CreateCollectionReq;
  30. import io.milvus.v2.service.collection.request.DropCollectionReq;
  31. import io.milvus.v2.service.vector.request.InsertReq;
  32. import io.milvus.v2.service.vector.request.QueryReq;
  33. import io.milvus.v2.service.vector.request.SearchReq;
  34. import io.milvus.v2.service.vector.request.data.BinaryVec;
  35. import io.milvus.v2.service.vector.response.QueryResp;
  36. import io.milvus.v2.service.vector.response.SearchResp;
  37. import java.nio.ByteBuffer;
  38. import java.util.*;
  39. public class BinaryVectorExample {
  40. private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v2";
  41. private static final String ID_FIELD = "id";
  42. private static final String VECTOR_FIELD = "vector";
  43. private static final Integer VECTOR_DIM = 512;
  44. public static void main(String[] args) {
  45. ConnectConfig config = ConnectConfig.builder()
  46. .uri("http://localhost:19530")
  47. .build();
  48. MilvusClientV2 client = new MilvusClientV2(config);
  49. // Drop collection if exists
  50. client.dropCollection(DropCollectionReq.builder()
  51. .collectionName(COLLECTION_NAME)
  52. .build());
  53. // Create collection
  54. CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
  55. .build();
  56. collectionSchema.addField(AddFieldReq.builder()
  57. .fieldName(ID_FIELD)
  58. .dataType(DataType.Int64)
  59. .isPrimaryKey(Boolean.TRUE)
  60. .build());
  61. collectionSchema.addField(AddFieldReq.builder()
  62. .fieldName(VECTOR_FIELD)
  63. .dataType(DataType.BinaryVector)
  64. .dimension(VECTOR_DIM)
  65. .build());
  66. List<IndexParam> indexes = new ArrayList<>();
  67. Map<String,Object> extraParams = new HashMap<>();
  68. extraParams.put("nlist",64);
  69. indexes.add(IndexParam.builder()
  70. .fieldName(VECTOR_FIELD)
  71. .indexType(IndexParam.IndexType.BIN_IVF_FLAT)
  72. .metricType(IndexParam.MetricType.HAMMING)
  73. .extraParams(extraParams)
  74. .build());
  75. CreateCollectionReq requestCreate = CreateCollectionReq.builder()
  76. .collectionName(COLLECTION_NAME)
  77. .collectionSchema(collectionSchema)
  78. .indexParams(indexes)
  79. .consistencyLevel(ConsistencyLevel.BOUNDED)
  80. .build();
  81. client.createCollection(requestCreate);
  82. System.out.println("Collection created");
  83. // Insert entities by rows
  84. int rowCount = 10000;
  85. List<JsonObject> rows = new ArrayList<>();
  86. Gson gson = new Gson();
  87. List<ByteBuffer> vectors = new ArrayList<>();
  88. for (long i = 0L; i < rowCount; ++i) {
  89. JsonObject row = new JsonObject();
  90. row.addProperty(ID_FIELD, i);
  91. ByteBuffer vector = CommonUtils.generateBinaryVector(VECTOR_DIM);
  92. vectors.add(vector);
  93. row.add(VECTOR_FIELD, gson.toJsonTree(vector.array()));
  94. rows.add(row);
  95. }
  96. client.insert(InsertReq.builder()
  97. .collectionName(COLLECTION_NAME)
  98. .data(rows)
  99. .build());
  100. // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
  101. QueryResp countR = client.query(QueryReq.builder()
  102. .collectionName(COLLECTION_NAME)
  103. .filter("")
  104. .outputFields(Collections.singletonList("count(*)"))
  105. .consistencyLevel(ConsistencyLevel.STRONG)
  106. .build());
  107. System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
  108. // Pick some vectors from the inserted vectors to search
  109. // Ensure the returned top1 item's ID should be equal to target vector's ID
  110. for (int i = 0; i < 10; i++) {
  111. Random ran = new Random();
  112. int k = ran.nextInt(rowCount);
  113. ByteBuffer targetVector = vectors.get(k);
  114. Map<String,Object> params = new HashMap<>();
  115. params.put("nprobe",16);
  116. SearchResp searchResp = client.search(SearchReq.builder()
  117. .collectionName(COLLECTION_NAME)
  118. .data(Collections.singletonList(new BinaryVec(targetVector)))
  119. .annsField(VECTOR_FIELD)
  120. .outputFields(Collections.singletonList(VECTOR_FIELD))
  121. .searchParams(params)
  122. .topK(3)
  123. .build());
  124. // The search() allows multiple target vectors to search in a batch.
  125. // Here we only input one vector to search, get the result of No.0 vector to check
  126. List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
  127. List<SearchResp.SearchResult> results = searchResults.get(0);
  128. System.out.printf("The result of No.%d target vector:\n", i);
  129. for (SearchResp.SearchResult result : results) {
  130. System.out.println(result.getEntity());
  131. System.out.printf("ID: %d, Score: %f, Vector: ", result.getId(), result.getScore());
  132. ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
  133. vector.rewind();
  134. while (vector.hasRemaining()) {
  135. System.out.print(Integer.toBinaryString(vector.get()));
  136. }
  137. System.out.println();
  138. }
  139. SearchResp.SearchResult firstResult = results.get(0);
  140. if ((long)firstResult.getId() != k) {
  141. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  142. firstResult.getId(), k));
  143. }
  144. }
  145. System.out.println("Search result is correct");
  146. // Retrieve some data
  147. int n = 99;
  148. QueryResp queryResp = client.query(QueryReq.builder()
  149. .collectionName(COLLECTION_NAME)
  150. .filter(String.format("id == %d", n))
  151. .outputFields(Collections.singletonList(VECTOR_FIELD))
  152. .build());
  153. List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
  154. if (queryResults.isEmpty()) {
  155. throw new RuntimeException("The query result is empty");
  156. } else {
  157. ByteBuffer vector = (ByteBuffer) queryResults.get(0).getEntity().get(VECTOR_FIELD);
  158. if (vector.compareTo(vectors.get(n)) != 0) {
  159. throw new RuntimeException("The query result is incorrect");
  160. }
  161. }
  162. System.out.println("Query result is correct");
  163. // Drop the collection if you don't need the collection anymore
  164. client.dropCollection(DropCollectionReq.builder()
  165. .collectionName(COLLECTION_NAME)
  166. .build());
  167. client.close();
  168. }
  169. }