HybridSearchExample.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus.v2;
  20. import com.google.gson.Gson;
  21. import com.google.gson.JsonObject;
  22. import io.milvus.v2.common.DataType;
  23. import io.milvus.v1.CommonUtils;
  24. import io.milvus.v2.client.ConnectConfig;
  25. import io.milvus.v2.client.MilvusClientV2;
  26. import io.milvus.v2.common.ConsistencyLevel;
  27. import io.milvus.v2.common.IndexParam;
  28. import io.milvus.v2.service.collection.request.AddFieldReq;
  29. import io.milvus.v2.service.collection.request.CreateCollectionReq;
  30. import io.milvus.v2.service.collection.request.DropCollectionReq;
  31. import io.milvus.v2.service.vector.request.AnnSearchReq;
  32. import io.milvus.v2.service.vector.request.HybridSearchReq;
  33. import io.milvus.v2.service.vector.request.InsertReq;
  34. import io.milvus.v2.service.vector.request.QueryReq;
  35. import io.milvus.v2.service.vector.request.data.BaseVector;
  36. import io.milvus.v2.service.vector.request.data.BinaryVec;
  37. import io.milvus.v2.service.vector.request.data.FloatVec;
  38. import io.milvus.v2.service.vector.request.data.SparseFloatVec;
  39. import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
  40. import io.milvus.v2.service.vector.response.QueryResp;
  41. import io.milvus.v2.service.vector.response.SearchResp;
  42. import java.util.*;
  43. public class HybridSearchExample {
  44. private static final MilvusClientV2 client;
  45. static {
  46. ConnectConfig config = ConnectConfig.builder()
  47. .uri("http://localhost:19530")
  48. .build();
  49. client = new MilvusClientV2(config);
  50. }
  51. private static final String COLLECTION_NAME = "java_sdk_example_hybrid_search_v2";
  52. private static final String ID_FIELD = "ID";
  53. private static final String FLOAT_VECTOR_FIELD = "float_vector";
  54. private static final Integer FLOAT_VECTOR_DIM = 128;
  55. private static final IndexParam.MetricType FLOAT_VECTOR_METRIC = IndexParam.MetricType.COSINE;
  56. private static final String BINARY_VECTOR_FIELD = "binary_vector";
  57. private static final Integer BINARY_VECTOR_DIM = 256;
  58. private static final IndexParam.MetricType BINARY_VECTOR_METRIC = IndexParam.MetricType.JACCARD;
  59. private static final String FLOAT16_VECTOR_FIELD = "float16_vector";
  60. private static final Integer FLOAT16_VECTOR_DIM = 256;
  61. private static final IndexParam.MetricType FLOAT16_VECTOR_METRIC = IndexParam.MetricType.L2;
  62. private static final String SPARSE_VECTOR_FIELD = "sparse_vector";
  63. private static final IndexParam.MetricType SPARSE_VECTOR_METRIC = IndexParam.MetricType.IP;
  64. private void createCollection() {
  65. client.dropCollection(DropCollectionReq.builder()
  66. .collectionName(COLLECTION_NAME)
  67. .build());
  68. // Create collection
  69. CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
  70. .build();
  71. collectionSchema.addField(AddFieldReq.builder()
  72. .fieldName(ID_FIELD)
  73. .dataType(DataType.Int64)
  74. .isPrimaryKey(Boolean.TRUE)
  75. .build());
  76. collectionSchema.addField(AddFieldReq.builder()
  77. .fieldName(FLOAT_VECTOR_FIELD)
  78. .dataType(DataType.FloatVector)
  79. .dimension(FLOAT_VECTOR_DIM)
  80. .build());
  81. collectionSchema.addField(AddFieldReq.builder()
  82. .fieldName(BINARY_VECTOR_FIELD)
  83. .dataType(DataType.BinaryVector)
  84. .dimension(BINARY_VECTOR_DIM)
  85. .build());
  86. collectionSchema.addField(AddFieldReq.builder()
  87. .fieldName(FLOAT16_VECTOR_FIELD)
  88. .dataType(DataType.Float16Vector)
  89. .dimension(FLOAT16_VECTOR_DIM)
  90. .build());
  91. collectionSchema.addField(AddFieldReq.builder()
  92. .fieldName(SPARSE_VECTOR_FIELD)
  93. .dataType(DataType.SparseFloatVector)
  94. .build());
  95. List<IndexParam> indexes = new ArrayList<>();
  96. Map<String,Object> fvParams = new HashMap<>();
  97. fvParams.put("nlist",128);
  98. fvParams.put("m",16);
  99. fvParams.put("nbits",8);
  100. indexes.add(IndexParam.builder()
  101. .fieldName(FLOAT_VECTOR_FIELD)
  102. .indexType(IndexParam.IndexType.IVF_PQ)
  103. .extraParams(fvParams)
  104. .metricType(FLOAT_VECTOR_METRIC)
  105. .build());
  106. indexes.add(IndexParam.builder()
  107. .fieldName(BINARY_VECTOR_FIELD)
  108. .indexType(IndexParam.IndexType.BIN_FLAT)
  109. .metricType(BINARY_VECTOR_METRIC)
  110. .build());
  111. Map<String,Object> fv16Params = new HashMap<>();
  112. fv16Params.clear();
  113. fv16Params.put("M",16);
  114. fv16Params.put("efConstruction",64);
  115. indexes.add(IndexParam.builder()
  116. .fieldName(FLOAT16_VECTOR_FIELD)
  117. .indexType(IndexParam.IndexType.HNSW)
  118. .extraParams(fv16Params)
  119. .metricType(FLOAT16_VECTOR_METRIC)
  120. .build());
  121. indexes.add(IndexParam.builder()
  122. .fieldName(SPARSE_VECTOR_FIELD)
  123. .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
  124. .metricType(SPARSE_VECTOR_METRIC)
  125. .build());
  126. CreateCollectionReq requestCreate = CreateCollectionReq.builder()
  127. .collectionName(COLLECTION_NAME)
  128. .collectionSchema(collectionSchema)
  129. .indexParams(indexes)
  130. .consistencyLevel(ConsistencyLevel.BOUNDED)
  131. .build();
  132. client.createCollection(requestCreate);
  133. System.out.println("Collection created");
  134. }
  135. private void insertData() {
  136. long idCount = 0;
  137. int rowCount = 10000;
  138. // Insert entities by rows
  139. List<JsonObject> rows = new ArrayList<>();
  140. Gson gson = new Gson();
  141. for (long i = 1L; i <= rowCount; ++i) {
  142. JsonObject row = new JsonObject();
  143. row.addProperty(ID_FIELD, idCount++);
  144. row.add(FLOAT_VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(FLOAT_VECTOR_DIM)));
  145. row.add(BINARY_VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateBinaryVector(BINARY_VECTOR_DIM).array()));
  146. row.add(FLOAT16_VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloat16Vector(FLOAT16_VECTOR_DIM, false).array()));
  147. row.add(SPARSE_VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateSparseVector()));
  148. rows.add(row);
  149. }
  150. client.insert(InsertReq.builder()
  151. .collectionName(COLLECTION_NAME)
  152. .data(rows)
  153. .build());
  154. System.out.printf("%d entities inserted by rows\n", rowCount);
  155. }
  156. private void hybridSearch() {
  157. // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
  158. QueryResp countR = client.query(QueryReq.builder()
  159. .collectionName(COLLECTION_NAME)
  160. .filter("")
  161. .outputFields(Collections.singletonList("count(*)"))
  162. .consistencyLevel(ConsistencyLevel.STRONG)
  163. .build());
  164. System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
  165. // Search on multiple vector fields
  166. int NQ = 2;
  167. List<BaseVector> floatVectors = new ArrayList<>();
  168. List<BaseVector> binaryVectors = new ArrayList<>();
  169. List<BaseVector> sparseVectors = new ArrayList<>();
  170. for (int i = 0; i < NQ; i++) {
  171. floatVectors.add(new FloatVec(CommonUtils.generateFloatVector(FLOAT_VECTOR_DIM)));
  172. binaryVectors.add(new BinaryVec(CommonUtils.generateBinaryVector(BINARY_VECTOR_DIM)));
  173. sparseVectors.add(new SparseFloatVec(CommonUtils.generateSparseVector()));
  174. }
  175. List<AnnSearchReq> searchRequests = new ArrayList<>();
  176. searchRequests.add(AnnSearchReq.builder()
  177. .vectorFieldName("float_vector")
  178. .vectors(floatVectors)
  179. .params("{\"nprobe\": 10}")
  180. .topK(10)
  181. .build());
  182. searchRequests.add(AnnSearchReq.builder()
  183. .vectorFieldName("binary_vector")
  184. .vectors(binaryVectors)
  185. .topK(50)
  186. .build());
  187. searchRequests.add(AnnSearchReq.builder()
  188. .vectorFieldName("sparse_vector")
  189. .vectors(sparseVectors)
  190. .topK(100)
  191. .build());
  192. HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
  193. .collectionName(COLLECTION_NAME)
  194. .searchRequests(searchRequests)
  195. .ranker(new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f)))
  196. .topK(5)
  197. .consistencyLevel(ConsistencyLevel.BOUNDED)
  198. .build();
  199. SearchResp searchResp = client.hybridSearch(hybridSearchReq);
  200. List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
  201. for (int i = 0; i < NQ; i++) {
  202. System.out.printf("============= Search result of No.%d vector =============\n", i);
  203. List<SearchResp.SearchResult> results = searchResults.get(i);
  204. for (SearchResp.SearchResult result : results) {
  205. System.out.printf("{id: %d, score: %f}%n", result.getId(), result.getScore());
  206. }
  207. }
  208. }
  209. private void dropCollection() {
  210. client.dropCollection(DropCollectionReq.builder()
  211. .collectionName(COLLECTION_NAME)
  212. .build());
  213. System.out.println("Collection dropped");
  214. }
  215. public static void main(String[] args) {
  216. io.milvus.v2.HybridSearchExample example = new io.milvus.v2.HybridSearchExample();
  217. example.createCollection();
  218. example.insertData();
  219. example.hybridSearch();
  220. example.dropCollection();
  221. client.close();
  222. }
  223. }