// Float16VectorExample.java
  1. package io.milvus.v2;
  2. import com.google.gson.Gson;
  3. import com.google.gson.JsonObject;
  4. import io.milvus.v1.CommonUtils;
  5. import io.milvus.v2.client.ConnectConfig;
  6. import io.milvus.v2.client.MilvusClientV2;
  7. import io.milvus.v2.common.ConsistencyLevel;
  8. import io.milvus.v2.common.DataType;
  9. import io.milvus.v2.common.IndexParam;
  10. import io.milvus.v2.service.collection.request.AddFieldReq;
  11. import io.milvus.v2.service.collection.request.CreateCollectionReq;
  12. import io.milvus.v2.service.collection.request.DropCollectionReq;
  13. import io.milvus.v2.service.collection.request.HasCollectionReq;
  14. import io.milvus.v2.service.vector.request.InsertReq;
  15. import io.milvus.v2.service.vector.request.QueryReq;
  16. import io.milvus.v2.service.vector.request.SearchReq;
  17. import io.milvus.v2.service.vector.request.data.BFloat16Vec;
  18. import io.milvus.v2.service.vector.request.data.BaseVector;
  19. import io.milvus.v2.service.vector.request.data.Float16Vec;
  20. import io.milvus.v2.service.vector.response.InsertResp;
  21. import io.milvus.v2.service.vector.response.QueryResp;
  22. import io.milvus.v2.service.vector.response.SearchResp;
  23. import java.nio.ByteBuffer;
  24. import java.util.*;
  25. public class Float16VectorExample {
  26. private static final String COLLECTION_NAME = "java_sdk_example_float16_vector_v2";
  27. private static final String ID_FIELD = "id";
  28. private static final String FP16_VECTOR_FIELD = "fp16_vector";
  29. private static final String BF16_VECTOR_FIELD = "bf16_vector";
  30. private static final Integer VECTOR_DIM = 128;
  31. private static final MilvusClientV2 milvusClient;
  32. static {
  33. milvusClient = new MilvusClientV2(ConnectConfig.builder()
  34. .uri("http://localhost:19530")
  35. .build());
  36. }
  37. private static void createCollection() {
  38. // drop the collection if you don't need the collection anymore
  39. Boolean has = milvusClient.hasCollection(HasCollectionReq.builder()
  40. .collectionName(COLLECTION_NAME)
  41. .build());
  42. if (has) {
  43. dropCollection();
  44. }
  45. // build a collection with two vector fields
  46. CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
  47. .build();
  48. collectionSchema.addField(AddFieldReq.builder()
  49. .fieldName(ID_FIELD)
  50. .dataType(DataType.Int64)
  51. .isPrimaryKey(Boolean.TRUE)
  52. .build());
  53. collectionSchema.addField(AddFieldReq.builder()
  54. .fieldName(FP16_VECTOR_FIELD)
  55. .dataType(io.milvus.v2.common.DataType.Float16Vector)
  56. .dimension(VECTOR_DIM)
  57. .build());
  58. collectionSchema.addField(AddFieldReq.builder()
  59. .fieldName(BF16_VECTOR_FIELD)
  60. .dataType(io.milvus.v2.common.DataType.BFloat16Vector)
  61. .dimension(VECTOR_DIM)
  62. .build());
  63. List<IndexParam> indexes = new ArrayList<>();
  64. Map<String,Object> extraParams = new HashMap<>();
  65. extraParams.put("nlist",64);
  66. indexes.add(IndexParam.builder()
  67. .fieldName(FP16_VECTOR_FIELD)
  68. .indexType(IndexParam.IndexType.IVF_FLAT)
  69. .metricType(IndexParam.MetricType.COSINE)
  70. .extraParams(extraParams)
  71. .build());
  72. indexes.add(IndexParam.builder()
  73. .fieldName(BF16_VECTOR_FIELD)
  74. .indexType(IndexParam.IndexType.FLAT)
  75. .metricType(IndexParam.MetricType.COSINE)
  76. .build());
  77. CreateCollectionReq requestCreate = CreateCollectionReq.builder()
  78. .collectionName(COLLECTION_NAME)
  79. .collectionSchema(collectionSchema)
  80. .indexParams(indexes)
  81. .consistencyLevel(ConsistencyLevel.BOUNDED)
  82. .build();
  83. milvusClient.createCollection(requestCreate);
  84. }
  85. private static void prepareData(int count) {
  86. List<JsonObject> rows = new ArrayList<>();
  87. Gson gson = new Gson();
  88. for (long i = 0; i < count; i++) {
  89. JsonObject row = new JsonObject();
  90. row.addProperty(ID_FIELD, i);
  91. ByteBuffer buf1 = CommonUtils.generateFloat16Vector(VECTOR_DIM, false);
  92. row.add(FP16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
  93. ByteBuffer buf2 = CommonUtils.generateFloat16Vector(VECTOR_DIM, true);
  94. row.add(BF16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
  95. rows.add(row);
  96. }
  97. InsertResp insertResp = milvusClient.insert(InsertReq.builder()
  98. .collectionName(COLLECTION_NAME)
  99. .data(rows)
  100. .build());
  101. System.out.println(insertResp.getInsertCnt() + " rows inserted");
  102. }
  103. private static void searchVectors(List<Long> taargetIDs, List<BaseVector> targetVectors, String vectorFieldName) {
  104. int topK = 5;
  105. SearchResp searchResp = milvusClient.search(SearchReq.builder()
  106. .collectionName(COLLECTION_NAME)
  107. .data(targetVectors)
  108. .annsField(vectorFieldName)
  109. .topK(topK)
  110. .outputFields(Collections.singletonList(vectorFieldName))
  111. .consistencyLevel(ConsistencyLevel.BOUNDED)
  112. .build());
  113. List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
  114. if (searchResults.isEmpty()) {
  115. throw new RuntimeException("The search result is empty");
  116. }
  117. for (int i = 0; i < searchResults.size(); i++) {
  118. List<SearchResp.SearchResult> results = searchResults.get(i);
  119. if (results.size() != topK) {
  120. throw new RuntimeException(String.format("The search result should contains top%d items", topK));
  121. }
  122. SearchResp.SearchResult topResult = results.get(0);
  123. long id = (long) topResult.getId();
  124. if (id != taargetIDs.get(i)) {
  125. throw new RuntimeException("The top1 id is incorrect");
  126. }
  127. Map<String, Object> entity = topResult.getEntity();
  128. ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName);
  129. if (!vectorBuf.equals(targetVectors.get(i).getData())) {
  130. throw new RuntimeException("The top1 output vector is incorrect");
  131. }
  132. System.out.println(results.get(0));
  133. }
  134. System.out.println("Search result of float16 vector is correct");
  135. }
  136. private static void search() {
  137. // retrieve some rows for search
  138. List<Long> targetIDs = Arrays.asList(999L, 2024L);
  139. QueryResp queryResp = milvusClient.query(QueryReq.builder()
  140. .collectionName(COLLECTION_NAME)
  141. .filter(ID_FIELD + " in " + targetIDs)
  142. .outputFields(Arrays.asList(FP16_VECTOR_FIELD, BF16_VECTOR_FIELD))
  143. .consistencyLevel(ConsistencyLevel.STRONG)
  144. .build());
  145. List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
  146. if (queryResults.isEmpty()) {
  147. throw new RuntimeException("The query result is empty");
  148. }
  149. List<BaseVector> targetFP16Vectors = new ArrayList<>();
  150. List<BaseVector> targetBF16Vectors = new ArrayList<>();
  151. for (QueryResp.QueryResult queryResult : queryResults) {
  152. Map<String, Object> entity = queryResult.getEntity();
  153. ByteBuffer f16VectorBuf = (ByteBuffer) entity.get(FP16_VECTOR_FIELD);
  154. targetFP16Vectors.add(new Float16Vec(f16VectorBuf));
  155. ByteBuffer bf16VectorBuf = (ByteBuffer) entity.get(BF16_VECTOR_FIELD);
  156. targetBF16Vectors.add(new BFloat16Vec(bf16VectorBuf));
  157. }
  158. // search float16 vector
  159. searchVectors(targetIDs, targetFP16Vectors, FP16_VECTOR_FIELD);
  160. // search bfloat16 vector
  161. searchVectors(targetIDs, targetBF16Vectors, BF16_VECTOR_FIELD);
  162. }
  163. private static void dropCollection() {
  164. milvusClient.dropCollection(DropCollectionReq.builder()
  165. .collectionName(COLLECTION_NAME)
  166. .build());
  167. System.out.println("Collection dropped");
  168. }
  169. public static void main(String[] args) {
  170. createCollection();
  171. prepareData(10000);
  172. search();
  173. dropCollection();
  174. }
  175. }