Float16VectorExample.java 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus.v2;
  20. import com.google.gson.Gson;
  21. import com.google.gson.JsonObject;
  22. import io.milvus.v1.CommonUtils;
  23. import io.milvus.v2.client.ConnectConfig;
  24. import io.milvus.v2.client.MilvusClientV2;
  25. import io.milvus.v2.common.ConsistencyLevel;
  26. import io.milvus.v2.common.DataType;
  27. import io.milvus.v2.common.IndexParam;
  28. import io.milvus.v2.service.collection.request.AddFieldReq;
  29. import io.milvus.v2.service.collection.request.CreateCollectionReq;
  30. import io.milvus.v2.service.collection.request.DropCollectionReq;
  31. import io.milvus.v2.service.collection.request.HasCollectionReq;
  32. import io.milvus.v2.service.vector.request.InsertReq;
  33. import io.milvus.v2.service.vector.request.QueryReq;
  34. import io.milvus.v2.service.vector.request.SearchReq;
  35. import io.milvus.v2.service.vector.request.data.BFloat16Vec;
  36. import io.milvus.v2.service.vector.request.data.BaseVector;
  37. import io.milvus.v2.service.vector.request.data.Float16Vec;
  38. import io.milvus.v2.service.vector.response.InsertResp;
  39. import io.milvus.v2.service.vector.response.QueryResp;
  40. import io.milvus.v2.service.vector.response.SearchResp;
  41. import java.nio.ByteBuffer;
  42. import java.util.*;
  43. public class Float16VectorExample {
  44. private static final String COLLECTION_NAME = "java_sdk_example_float16_vector_v2";
  45. private static final String ID_FIELD = "id";
  46. private static final String FP16_VECTOR_FIELD = "fp16_vector";
  47. private static final String BF16_VECTOR_FIELD = "bf16_vector";
  48. private static final Integer VECTOR_DIM = 128;
  49. private static final MilvusClientV2 client;
  50. static {
  51. client = new MilvusClientV2(ConnectConfig.builder()
  52. .uri("http://localhost:19530")
  53. .build());
  54. }
  55. private static void createCollection() {
  56. // Drop the collection if you don't need the collection anymore
  57. Boolean has = client.hasCollection(HasCollectionReq.builder()
  58. .collectionName(COLLECTION_NAME)
  59. .build());
  60. if (has) {
  61. dropCollection();
  62. }
  63. // Build a collection with two vector fields
  64. CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
  65. .build();
  66. collectionSchema.addField(AddFieldReq.builder()
  67. .fieldName(ID_FIELD)
  68. .dataType(DataType.Int64)
  69. .isPrimaryKey(Boolean.TRUE)
  70. .build());
  71. collectionSchema.addField(AddFieldReq.builder()
  72. .fieldName(FP16_VECTOR_FIELD)
  73. .dataType(io.milvus.v2.common.DataType.Float16Vector)
  74. .dimension(VECTOR_DIM)
  75. .build());
  76. collectionSchema.addField(AddFieldReq.builder()
  77. .fieldName(BF16_VECTOR_FIELD)
  78. .dataType(io.milvus.v2.common.DataType.BFloat16Vector)
  79. .dimension(VECTOR_DIM)
  80. .build());
  81. List<IndexParam> indexes = new ArrayList<>();
  82. Map<String,Object> extraParams = new HashMap<>();
  83. extraParams.put("nlist",64);
  84. indexes.add(IndexParam.builder()
  85. .fieldName(FP16_VECTOR_FIELD)
  86. .indexType(IndexParam.IndexType.IVF_FLAT)
  87. .metricType(IndexParam.MetricType.COSINE)
  88. .extraParams(extraParams)
  89. .build());
  90. indexes.add(IndexParam.builder()
  91. .fieldName(BF16_VECTOR_FIELD)
  92. .indexType(IndexParam.IndexType.FLAT)
  93. .metricType(IndexParam.MetricType.COSINE)
  94. .build());
  95. CreateCollectionReq requestCreate = CreateCollectionReq.builder()
  96. .collectionName(COLLECTION_NAME)
  97. .collectionSchema(collectionSchema)
  98. .indexParams(indexes)
  99. .consistencyLevel(ConsistencyLevel.BOUNDED)
  100. .build();
  101. client.createCollection(requestCreate);
  102. }
  103. private static void prepareData(int count) {
  104. List<JsonObject> rows = new ArrayList<>();
  105. Gson gson = new Gson();
  106. for (long i = 0; i < count; i++) {
  107. JsonObject row = new JsonObject();
  108. row.addProperty(ID_FIELD, i);
  109. ByteBuffer buf1 = CommonUtils.generateFloat16Vector(VECTOR_DIM, false);
  110. row.add(FP16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
  111. ByteBuffer buf2 = CommonUtils.generateFloat16Vector(VECTOR_DIM, true);
  112. row.add(BF16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
  113. rows.add(row);
  114. }
  115. InsertResp insertResp = client.insert(InsertReq.builder()
  116. .collectionName(COLLECTION_NAME)
  117. .data(rows)
  118. .build());
  119. System.out.println(insertResp.getInsertCnt() + " rows inserted");
  120. }
  121. private static void searchVectors(List<Long> taargetIDs, List<BaseVector> targetVectors, String vectorFieldName) {
  122. int topK = 5;
  123. SearchResp searchResp = client.search(SearchReq.builder()
  124. .collectionName(COLLECTION_NAME)
  125. .data(targetVectors)
  126. .annsField(vectorFieldName)
  127. .topK(topK)
  128. .outputFields(Collections.singletonList(vectorFieldName))
  129. .consistencyLevel(ConsistencyLevel.BOUNDED)
  130. .build());
  131. List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
  132. if (searchResults.isEmpty()) {
  133. throw new RuntimeException("The search result is empty");
  134. }
  135. for (int i = 0; i < searchResults.size(); i++) {
  136. List<SearchResp.SearchResult> results = searchResults.get(i);
  137. if (results.size() != topK) {
  138. throw new RuntimeException(String.format("The search result should contains top%d items", topK));
  139. }
  140. SearchResp.SearchResult topResult = results.get(0);
  141. long id = (long) topResult.getId();
  142. if (id != taargetIDs.get(i)) {
  143. throw new RuntimeException("The top1 id is incorrect");
  144. }
  145. Map<String, Object> entity = topResult.getEntity();
  146. ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName);
  147. if (!vectorBuf.equals(targetVectors.get(i).getData())) {
  148. throw new RuntimeException("The top1 output vector is incorrect");
  149. }
  150. System.out.println(results.get(0));
  151. }
  152. System.out.println("Search result of float16 vector is correct");
  153. }
  154. private static void search() {
  155. // Retrieve some rows for search
  156. List<Long> targetIDs = Arrays.asList(999L, 2024L);
  157. QueryResp queryResp = client.query(QueryReq.builder()
  158. .collectionName(COLLECTION_NAME)
  159. .filter(ID_FIELD + " in " + targetIDs)
  160. .outputFields(Arrays.asList(FP16_VECTOR_FIELD, BF16_VECTOR_FIELD))
  161. .consistencyLevel(ConsistencyLevel.STRONG)
  162. .build());
  163. List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
  164. if (queryResults.isEmpty()) {
  165. throw new RuntimeException("The query result is empty");
  166. }
  167. List<BaseVector> targetFP16Vectors = new ArrayList<>();
  168. List<BaseVector> targetBF16Vectors = new ArrayList<>();
  169. for (QueryResp.QueryResult queryResult : queryResults) {
  170. Map<String, Object> entity = queryResult.getEntity();
  171. ByteBuffer f16VectorBuf = (ByteBuffer) entity.get(FP16_VECTOR_FIELD);
  172. targetFP16Vectors.add(new Float16Vec(f16VectorBuf));
  173. ByteBuffer bf16VectorBuf = (ByteBuffer) entity.get(BF16_VECTOR_FIELD);
  174. targetBF16Vectors.add(new BFloat16Vec(bf16VectorBuf));
  175. }
  176. // Search float16 vector
  177. searchVectors(targetIDs, targetFP16Vectors, FP16_VECTOR_FIELD);
  178. // Search bfloat16 vector
  179. searchVectors(targetIDs, targetBF16Vectors, BF16_VECTOR_FIELD);
  180. }
  181. private static void dropCollection() {
  182. client.dropCollection(DropCollectionReq.builder()
  183. .collectionName(COLLECTION_NAME)
  184. .build());
  185. System.out.println("Collection dropped");
  186. }
  187. public static void main(String[] args) {
  188. createCollection();
  189. prepareData(10000);
  190. search();
  191. dropCollection();
  192. client.close();
  193. }
  194. }