// Float16VectorExample.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
  19. package io.milvus.v1;
  20. import com.google.gson.Gson;
  21. import com.google.gson.JsonObject;
  22. import io.milvus.client.MilvusServiceClient;
  23. import io.milvus.common.clientenum.ConsistencyLevelEnum;
  24. import io.milvus.grpc.*;
  25. import io.milvus.param.*;
  26. import io.milvus.param.collection.*;
  27. import io.milvus.param.dml.*;
  28. import io.milvus.param.index.*;
  29. import io.milvus.response.*;
  30. import java.nio.ByteBuffer;
  31. import java.util.*;
  32. public class Float16VectorExample {
  33. private static final String COLLECTION_NAME = "java_sdk_example_float16";
  34. private static final String ID_FIELD = "id";
  35. private static final String VECTOR_FIELD = "vector";
  36. private static final Integer VECTOR_DIM = 128;
  37. private static void testFloat16(boolean bfloat16) {
  38. DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
  39. System.out.printf("=================== %s ===================\n", dataType.name());
  40. // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
  41. MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
  42. .withHost("localhost")
  43. .withPort(19530)
  44. .build());
  45. // drop the collection if you don't need the collection anymore
  46. R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
  47. .withCollectionName(COLLECTION_NAME)
  48. .build());
  49. CommonUtils.handleResponseStatus(hasR);
  50. if (hasR.getData()) {
  51. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  52. .withCollectionName(COLLECTION_NAME)
  53. .build());
  54. }
  55. // Define fields
  56. List<FieldType> fieldsSchema = Arrays.asList(
  57. FieldType.newBuilder()
  58. .withName(ID_FIELD)
  59. .withDataType(DataType.Int64)
  60. .withPrimaryKey(true)
  61. .withAutoID(false)
  62. .build(),
  63. FieldType.newBuilder()
  64. .withName(VECTOR_FIELD)
  65. .withDataType(dataType)
  66. .withDimension(VECTOR_DIM)
  67. .build()
  68. );
  69. // Create the collection
  70. R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
  71. .withCollectionName(COLLECTION_NAME)
  72. .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
  73. .withFieldTypes(fieldsSchema)
  74. .build());
  75. CommonUtils.handleResponseStatus(ret);
  76. System.out.println("Collection created");
  77. // Insert entities by columns
  78. int rowCount = 10000;
  79. List<Long> ids = new ArrayList<>();
  80. for (long i = 0L; i < rowCount; ++i) {
  81. ids.add(i);
  82. }
  83. List<ByteBuffer> vectors = CommonUtils.generateFloat16Vectors(VECTOR_DIM, rowCount, bfloat16);
  84. List<InsertParam.Field> fieldsInsert = new ArrayList<>();
  85. fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
  86. fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
  87. R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
  88. .withCollectionName(COLLECTION_NAME)
  89. .withFields(fieldsInsert)
  90. .build());
  91. CommonUtils.handleResponseStatus(insertR);
  92. // Insert entities by rows
  93. List<JsonObject> rows = new ArrayList<>();
  94. Gson gson = new Gson();
  95. for (long i = 1L; i <= rowCount; ++i) {
  96. JsonObject row = new JsonObject();
  97. row.addProperty(ID_FIELD, rowCount + i);
  98. row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloat16Vector(VECTOR_DIM, bfloat16).array()));
  99. rows.add(row);
  100. }
  101. insertR = milvusClient.insert(InsertParam.newBuilder()
  102. .withCollectionName(COLLECTION_NAME)
  103. .withRows(rows)
  104. .build());
  105. CommonUtils.handleResponseStatus(insertR);
  106. // Flush the data to storage for testing purpose
  107. // Note that no need to manually call flush interface in practice
  108. R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
  109. addCollectionName(COLLECTION_NAME).
  110. build());
  111. CommonUtils.handleResponseStatus(flushR);
  112. System.out.println("Entities inserted");
  113. // Specify an index type on the vector field.
  114. ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
  115. .withCollectionName(COLLECTION_NAME)
  116. .withFieldName(VECTOR_FIELD)
  117. .withIndexType(IndexType.IVF_FLAT)
  118. .withMetricType(MetricType.L2)
  119. .withExtraParam("{\"nlist\":128}")
  120. .build());
  121. CommonUtils.handleResponseStatus(ret);
  122. System.out.println("Index created");
  123. // Call loadCollection() to enable automatically loading data into memory for searching
  124. ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
  125. .withCollectionName(COLLECTION_NAME)
  126. .build());
  127. CommonUtils.handleResponseStatus(ret);
  128. System.out.println("Collection loaded");
  129. // Pick some vectors from the inserted vectors to search
  130. // Ensure the returned top1 item's ID should be equal to target vector's ID
  131. for (int i = 0; i < 10; i++) {
  132. Random ran = new Random();
  133. int k = ran.nextInt(rowCount);
  134. ByteBuffer targetVector = vectors.get(k);
  135. SearchParam.Builder builder = SearchParam.newBuilder()
  136. .withCollectionName(COLLECTION_NAME)
  137. .withMetricType(MetricType.L2)
  138. .withTopK(3)
  139. .withVectorFieldName(VECTOR_FIELD)
  140. .addOutField(VECTOR_FIELD)
  141. .withParams("{\"nprobe\":32}");
  142. if (bfloat16) {
  143. builder.withBFloat16Vectors(Collections.singletonList(targetVector));
  144. } else {
  145. builder.withFloat16Vectors(Collections.singletonList(targetVector));
  146. }
  147. R<SearchResults> searchRet = milvusClient.search(builder.build());
  148. CommonUtils.handleResponseStatus(searchRet);
  149. // The search() allows multiple target vectors to search in a batch.
  150. // Here we only input one vector to search, get the result of No.0 vector to check
  151. SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
  152. List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
  153. System.out.printf("The result of No.%d target vector:\n", i);
  154. for (SearchResultsWrapper.IDScore score : scores) {
  155. System.out.println(score);
  156. }
  157. if (scores.get(0).getLongID() != k) {
  158. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  159. scores.get(0).getLongID(), k));
  160. }
  161. }
  162. System.out.println("Search result is correct");
  163. // Retrieve some data
  164. int n = 99;
  165. R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
  166. .withCollectionName(COLLECTION_NAME)
  167. .withExpr(String.format("id == %d", n))
  168. .addOutField(VECTOR_FIELD)
  169. .build());
  170. CommonUtils.handleResponseStatus(queryR);
  171. QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
  172. FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
  173. List<?> r = field.getFieldData();
  174. if (r.isEmpty()) {
  175. throw new RuntimeException("The query result is empty");
  176. } else {
  177. ByteBuffer bf = (ByteBuffer) r.get(0);
  178. if (!bf.equals(vectors.get(n))) {
  179. throw new RuntimeException("The query result is incorrect");
  180. }
  181. }
  182. System.out.println("Query result is correct");
  183. // drop the collection if you don't need the collection anymore
  184. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  185. .withCollectionName(COLLECTION_NAME)
  186. .build());
  187. System.out.println("Collection dropped");
  188. milvusClient.close();
  189. }
  190. public static void main(String[] args) {
  191. testFloat16(true);
  192. testFloat16(false);
  193. }
  194. }