// SparseVectorExample.java (8.2 KB)
  /*
   * Licensed to the Apache Software Foundation (ASF) under one
   * or more contributor license agreements. See the NOTICE file
   * distributed with this work for additional information
   * regarding copyright ownership. The ASF licenses this file
   * to you under the Apache License, Version 2.0 (the
   * "License"); you may not use this file except in compliance
   * with the License. You may obtain a copy of the License at
   *
   *   http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing,
   * software distributed under the License is distributed on an
   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   * KIND, either express or implied. See the License for the
   * specific language governing permissions and limitations
   * under the License.
   */
  19. package io.milvus;
  20. import io.milvus.client.MilvusServiceClient;
  21. import io.milvus.common.clientenum.ConsistencyLevelEnum;
  22. import io.milvus.grpc.*;
  23. import io.milvus.param.*;
  24. import io.milvus.param.collection.*;
  25. import io.milvus.param.dml.InsertParam;
  26. import io.milvus.param.dml.QueryParam;
  27. import io.milvus.param.dml.SearchParam;
  28. import io.milvus.param.index.CreateIndexParam;
  29. import io.milvus.response.FieldDataWrapper;
  30. import io.milvus.response.QueryResultsWrapper;
  31. import io.milvus.response.SearchResultsWrapper;
  32. import java.util.*;
  33. public class SparseVectorExample {
  34. private static final String COLLECTION_NAME = "java_sdk_example_sparse";
  35. private static final String ID_FIELD = "id";
  36. private static final String VECTOR_FIELD = "vector";
  37. public static void main(String[] args) {
  38. // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
  39. MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
  40. .withHost("localhost")
  41. .withPort(19530)
  42. .build());
  43. // drop the collection if you don't need the collection anymore
  44. R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
  45. .withCollectionName(COLLECTION_NAME)
  46. .build());
  47. CommonUtils.handleResponseStatus(hasR);
  48. if (hasR.getData()) {
  49. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  50. .withCollectionName(COLLECTION_NAME)
  51. .build());
  52. }
  53. // Define fields
  54. List<FieldType> fieldsSchema = Arrays.asList(
  55. FieldType.newBuilder()
  56. .withName(ID_FIELD)
  57. .withDataType(DataType.Int64)
  58. .withPrimaryKey(true)
  59. .withAutoID(false)
  60. .build(),
  61. FieldType.newBuilder()
  62. .withName(VECTOR_FIELD)
  63. .withDataType(DataType.SparseFloatVector)
  64. .build()
  65. );
  66. // Create the collection
  67. R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
  68. .withCollectionName(COLLECTION_NAME)
  69. .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
  70. .withFieldTypes(fieldsSchema)
  71. .build());
  72. CommonUtils.handleResponseStatus(ret);
  73. System.out.println("Collection created");
  74. // Insert entities
  75. int rowCount = 10000;
  76. List<Long> ids = new ArrayList<>();
  77. for (long i = 0L; i < rowCount; ++i) {
  78. ids.add(i);
  79. }
  80. List<SortedMap<Long, Float>> vectors = CommonUtils.generateSparseVectors(rowCount);
  81. List<InsertParam.Field> fieldsInsert = new ArrayList<>();
  82. fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
  83. fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
  84. InsertParam insertParam = InsertParam.newBuilder()
  85. .withCollectionName(COLLECTION_NAME)
  86. .withFields(fieldsInsert)
  87. .build();
  88. R<MutationResult> insertR = milvusClient.insert(insertParam);
  89. CommonUtils.handleResponseStatus(insertR);
  90. // Flush the data to storage for testing purpose
  91. // Note that no need to manually call flush interface in practice
  92. R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
  93. addCollectionName(COLLECTION_NAME).
  94. build());
  95. CommonUtils.handleResponseStatus(flushR);
  96. System.out.println("Entities inserted");
  97. // Specify an index type on the vector field.
  98. ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
  99. .withCollectionName(COLLECTION_NAME)
  100. .withFieldName(VECTOR_FIELD)
  101. .withIndexType(IndexType.SPARSE_WAND)
  102. .withMetricType(MetricType.IP)
  103. .withExtraParam("{\"drop_ratio_build\":0.2}")
  104. .build());
  105. CommonUtils.handleResponseStatus(ret);
  106. System.out.println("Index created");
  107. // Call loadCollection() to enable automatically loading data into memory for searching
  108. ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
  109. .withCollectionName(COLLECTION_NAME)
  110. .build());
  111. CommonUtils.handleResponseStatus(ret);
  112. System.out.println("Collection loaded");
  113. // Pick some vectors from the inserted vectors to search
  114. // Ensure the returned top1 item's ID should be equal to target vector's ID
  115. for (int i = 0; i < 10; i++) {
  116. Random ran = new Random();
  117. int k = ran.nextInt(rowCount);
  118. SortedMap<Long, Float> targetVector = vectors.get(k);
  119. R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
  120. .withCollectionName(COLLECTION_NAME)
  121. .withMetricType(MetricType.IP)
  122. .withTopK(3)
  123. .withSparseFloatVectors(Collections.singletonList(targetVector))
  124. .withVectorFieldName(VECTOR_FIELD)
  125. .addOutField(VECTOR_FIELD)
  126. .withParams("{\"drop_ratio_search\":0.2}")
  127. .build());
  128. CommonUtils.handleResponseStatus(searchRet);
  129. // The search() allows multiple target vectors to search in a batch.
  130. // Here we only input one vector to search, get the result of No.0 vector to check
  131. SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
  132. List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
  133. System.out.printf("The result of No.%d target vector:\n", i);
  134. for (SearchResultsWrapper.IDScore score : scores) {
  135. System.out.println(score);
  136. }
  137. if (scores.get(0).getLongID() != k) {
  138. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  139. scores.get(0).getLongID(), k));
  140. }
  141. }
  142. System.out.println("Search result is correct");
  143. // Retrieve some data
  144. int n = 99;
  145. R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
  146. .withCollectionName(COLLECTION_NAME)
  147. .withExpr(String.format("id == %d", n))
  148. .addOutField(VECTOR_FIELD)
  149. .build());
  150. CommonUtils.handleResponseStatus(queryR);
  151. QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
  152. FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
  153. List<?> r = field.getFieldData();
  154. if (r.isEmpty()) {
  155. throw new RuntimeException("The query result is empty");
  156. } else {
  157. SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) r.get(0);
  158. if (!sparse.equals(vectors.get(n))) {
  159. throw new RuntimeException("The query result is incorrect");
  160. }
  161. }
  162. System.out.println("Query result is correct");
  163. // drop the collection if you don't need the collection anymore
  164. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  165. .withCollectionName(COLLECTION_NAME)
  166. .build());
  167. System.out.println("Collection dropped");
  168. milvusClient.close();
  169. }
  170. }