Int8VectorExample.java 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package io.milvus.v2;
  2. import com.google.gson.Gson;
  3. import com.google.gson.JsonObject;
  4. import io.milvus.v1.CommonUtils;
  5. import io.milvus.v2.client.ConnectConfig;
  6. import io.milvus.v2.client.MilvusClientV2;
  7. import io.milvus.v2.common.ConsistencyLevel;
  8. import io.milvus.v2.common.DataType;
  9. import io.milvus.v2.common.IndexParam;
  10. import io.milvus.v2.service.collection.request.AddFieldReq;
  11. import io.milvus.v2.service.collection.request.CreateCollectionReq;
  12. import io.milvus.v2.service.collection.request.DropCollectionReq;
  13. import io.milvus.v2.service.vector.request.InsertReq;
  14. import io.milvus.v2.service.vector.request.QueryReq;
  15. import io.milvus.v2.service.vector.request.SearchReq;
  16. import io.milvus.v2.service.vector.request.data.BinaryVec;
  17. import io.milvus.v2.service.vector.request.data.Int8Vec;
  18. import io.milvus.v2.service.vector.response.QueryResp;
  19. import io.milvus.v2.service.vector.response.SearchResp;
  20. import java.nio.ByteBuffer;
  21. import java.util.*;
  22. public class Int8VectorExample {
  23. private static final String COLLECTION_NAME = "java_sdk_example_int8_vector_v2";
  24. private static final String ID_FIELD = "id";
  25. private static final String VECTOR_FIELD = "vector";
  26. private static final Integer VECTOR_DIM = 128;
  27. private static List<ByteBuffer> generateInt8Vectors(int count) {
  28. Random RANDOM = new Random();
  29. List<ByteBuffer> vectors = new ArrayList<>();
  30. for (int i = 0; i < count; i++) {
  31. ByteBuffer vector = ByteBuffer.allocate(VECTOR_DIM);
  32. for (int k = 0; k < VECTOR_DIM; ++k) {
  33. vector.put((byte) (RANDOM.nextInt(256) - 128));
  34. }
  35. vectors.add(vector);
  36. }
  37. return vectors;
  38. }
  39. public static void main(String[] args) {
  40. ConnectConfig config = ConnectConfig.builder()
  41. .uri("http://localhost:19530")
  42. .build();
  43. MilvusClientV2 client = new MilvusClientV2(config);
  44. // Drop collection if exists
  45. client.dropCollection(DropCollectionReq.builder()
  46. .collectionName(COLLECTION_NAME)
  47. .build());
  48. // Create collection
  49. CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
  50. .build();
  51. collectionSchema.addField(AddFieldReq.builder()
  52. .fieldName(ID_FIELD)
  53. .dataType(DataType.Int64)
  54. .isPrimaryKey(Boolean.TRUE)
  55. .build());
  56. collectionSchema.addField(AddFieldReq.builder()
  57. .fieldName(VECTOR_FIELD)
  58. .dataType(DataType.Int8Vector)
  59. .dimension(VECTOR_DIM)
  60. .build());
  61. List<IndexParam> indexes = new ArrayList<>();
  62. Map<String,Object> extraParams = new HashMap<>();
  63. extraParams.put("M", 64);
  64. extraParams.put("efConstruction", 200);
  65. indexes.add(IndexParam.builder()
  66. .fieldName(VECTOR_FIELD)
  67. .indexType(IndexParam.IndexType.HNSW)
  68. .metricType(IndexParam.MetricType.L2)
  69. .extraParams(extraParams)
  70. .build());
  71. CreateCollectionReq requestCreate = CreateCollectionReq.builder()
  72. .collectionName(COLLECTION_NAME)
  73. .collectionSchema(collectionSchema)
  74. .indexParams(indexes)
  75. .consistencyLevel(ConsistencyLevel.BOUNDED)
  76. .build();
  77. client.createCollection(requestCreate);
  78. System.out.println("Collection created");
  79. // Insert entities by rows
  80. int rowCount = 10000;
  81. List<ByteBuffer> vectors = generateInt8Vectors(rowCount);
  82. List<JsonObject> rows = new ArrayList<>();
  83. Gson gson = new Gson();
  84. for (long i = 0L; i < rowCount; ++i) {
  85. JsonObject row = new JsonObject();
  86. row.addProperty(ID_FIELD, i);
  87. ByteBuffer vector = vectors.get((int)i);
  88. row.add(VECTOR_FIELD, gson.toJsonTree(vector.array()));
  89. rows.add(row);
  90. }
  91. client.insert(InsertReq.builder()
  92. .collectionName(COLLECTION_NAME)
  93. .data(rows)
  94. .build());
  95. // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
  96. QueryResp countR = client.query(QueryReq.builder()
  97. .collectionName(COLLECTION_NAME)
  98. .filter("")
  99. .outputFields(Collections.singletonList("count(*)"))
  100. .consistencyLevel(ConsistencyLevel.STRONG)
  101. .build());
  102. System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
  103. // Pick some vectors from the inserted vectors to search
  104. // Ensure the returned top1 item's ID should be equal to target vector's ID
  105. for (int i = 0; i < 10; i++) {
  106. Random ran = new Random();
  107. int k = ran.nextInt(rowCount);
  108. ByteBuffer targetVector = vectors.get(k);
  109. SearchResp searchResp = client.search(SearchReq.builder()
  110. .collectionName(COLLECTION_NAME)
  111. .data(Collections.singletonList(new Int8Vec(targetVector)))
  112. .annsField(VECTOR_FIELD)
  113. .outputFields(Collections.singletonList(VECTOR_FIELD))
  114. .topK(3)
  115. .build());
  116. // The search() allows multiple target vectors to search in a batch.
  117. // Here we only input one vector to search, get the result of No.0 vector to check
  118. List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
  119. List<SearchResp.SearchResult> results = searchResults.get(0);
  120. System.out.printf("\nThe result of No.%d vector %s:\n", k, Arrays.toString(targetVector.array()));
  121. for (SearchResp.SearchResult result : results) {
  122. System.out.printf("ID: %d, Score: %f, Vector: ", (long)result.getId(), result.getScore());
  123. ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
  124. System.out.print(Arrays.toString(vector.array()));
  125. System.out.println();
  126. }
  127. SearchResp.SearchResult firstResult = results.get(0);
  128. if ((long)firstResult.getId() != k) {
  129. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  130. (long)firstResult.getId(), k));
  131. }
  132. }
  133. System.out.println("Search result is correct");
  134. // Retrieve some data
  135. int n = 99;
  136. QueryResp queryResp = client.query(QueryReq.builder()
  137. .collectionName(COLLECTION_NAME)
  138. .filter(String.format("id == %d", n))
  139. .outputFields(Collections.singletonList(VECTOR_FIELD))
  140. .build());
  141. List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
  142. if (queryResults.isEmpty()) {
  143. throw new RuntimeException("The query result is empty");
  144. } else {
  145. ByteBuffer vector = (ByteBuffer) queryResults.get(0).getEntity().get(VECTOR_FIELD);
  146. if (vector.compareTo(vectors.get(n)) != 0) {
  147. throw new RuntimeException("The query result is incorrect");
  148. }
  149. }
  150. System.out.println("Query result is correct");
  151. // Drop the collection if you don't need the collection anymore
  152. client.dropCollection(DropCollectionReq.builder()
  153. .collectionName(COLLECTION_NAME)
  154. .build());
  155. client.close();
  156. }
  157. }