Float16VectorExample.java 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus;
  20. import io.milvus.client.MilvusServiceClient;
  21. import io.milvus.common.clientenum.ConsistencyLevelEnum;
  22. import io.milvus.grpc.*;
  23. import io.milvus.param.*;
  24. import io.milvus.param.collection.*;
  25. import io.milvus.param.dml.*;
  26. import io.milvus.param.index.*;
  27. import io.milvus.response.*;
  28. import java.nio.ByteBuffer;
  29. import java.util.*;
  30. import org.tensorflow.ndarray.buffer.ByteDataBuffer;
  31. import org.tensorflow.types.*;
  32. public class Float16VectorExample {
  33. private static final String COLLECTION_NAME = "java_sdk_example_float16";
  34. private static final String ID_FIELD = "id";
  35. private static final String VECTOR_FIELD = "vector";
  36. private static final Integer VECTOR_DIM = 128;
  37. private static void testFloat16(boolean bfloat16) {
  38. DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
  39. System.out.printf("=================== %s ===================\n", dataType.name());
  40. // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
  41. MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
  42. .withHost("localhost")
  43. .withPort(19530)
  44. .build());
  45. // drop the collection if you don't need the collection anymore
  46. R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
  47. .withCollectionName(COLLECTION_NAME)
  48. .build());
  49. CommonUtils.handleResponseStatus(hasR);
  50. if (hasR.getData()) {
  51. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  52. .withCollectionName(COLLECTION_NAME)
  53. .build());
  54. }
  55. // Define fields
  56. List<FieldType> fieldsSchema = Arrays.asList(
  57. FieldType.newBuilder()
  58. .withName(ID_FIELD)
  59. .withDataType(DataType.Int64)
  60. .withPrimaryKey(true)
  61. .withAutoID(false)
  62. .build(),
  63. FieldType.newBuilder()
  64. .withName(VECTOR_FIELD)
  65. .withDataType(dataType)
  66. .withDimension(VECTOR_DIM)
  67. .build()
  68. );
  69. // Create the collection
  70. R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
  71. .withCollectionName(COLLECTION_NAME)
  72. .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
  73. .withFieldTypes(fieldsSchema)
  74. .build());
  75. CommonUtils.handleResponseStatus(ret);
  76. System.out.println("Collection created");
  77. // Insert entities
  78. int rowCount = 10000;
  79. List<Long> ids = new ArrayList<>();
  80. for (long i = 0L; i < rowCount; ++i) {
  81. ids.add(i);
  82. }
  83. List<ByteBuffer> vectors = CommonUtils.generateFloat16Vectors(VECTOR_DIM, rowCount, bfloat16);
  84. List<InsertParam.Field> fieldsInsert = new ArrayList<>();
  85. fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
  86. fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
  87. InsertParam insertParam = InsertParam.newBuilder()
  88. .withCollectionName(COLLECTION_NAME)
  89. .withFields(fieldsInsert)
  90. .build();
  91. R<MutationResult> insertR = milvusClient.insert(insertParam);
  92. CommonUtils.handleResponseStatus(insertR);
  93. // Flush the data to storage for testing purpose
  94. // Note that no need to manually call flush interface in practice
  95. R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
  96. addCollectionName(COLLECTION_NAME).
  97. build());
  98. CommonUtils.handleResponseStatus(flushR);
  99. System.out.println("Entities inserted");
  100. // Specify an index type on the vector field.
  101. ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
  102. .withCollectionName(COLLECTION_NAME)
  103. .withFieldName(VECTOR_FIELD)
  104. .withIndexType(IndexType.IVF_FLAT)
  105. .withMetricType(MetricType.L2)
  106. .withExtraParam("{\"nlist\":128}")
  107. .build());
  108. CommonUtils.handleResponseStatus(ret);
  109. System.out.println("Index created");
  110. // Call loadCollection() to enable automatically loading data into memory for searching
  111. ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
  112. .withCollectionName(COLLECTION_NAME)
  113. .build());
  114. CommonUtils.handleResponseStatus(ret);
  115. System.out.println("Collection loaded");
  116. // Pick some vectors from the inserted vectors to search
  117. // Ensure the returned top1 item's ID should be equal to target vector's ID
  118. for (int i = 0; i < 10; i++) {
  119. Random ran = new Random();
  120. int k = ran.nextInt(rowCount);
  121. ByteBuffer targetVector = vectors.get(k);
  122. SearchParam.Builder builder = SearchParam.newBuilder()
  123. .withCollectionName(COLLECTION_NAME)
  124. .withMetricType(MetricType.L2)
  125. .withTopK(3)
  126. .withVectorFieldName(VECTOR_FIELD)
  127. .addOutField(VECTOR_FIELD)
  128. .withParams("{\"nprobe\":32}");
  129. if (bfloat16) {
  130. builder.withBFloat16Vectors(Collections.singletonList(targetVector));
  131. } else {
  132. builder.withFloat16Vectors(Collections.singletonList(targetVector));
  133. }
  134. R<SearchResults> searchRet = milvusClient.search(builder.build());
  135. CommonUtils.handleResponseStatus(searchRet);
  136. // The search() allows multiple target vectors to search in a batch.
  137. // Here we only input one vector to search, get the result of No.0 vector to check
  138. SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
  139. List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
  140. System.out.printf("The result of No.%d target vector:\n", i);
  141. for (SearchResultsWrapper.IDScore score : scores) {
  142. System.out.println(score);
  143. }
  144. if (scores.get(0).getLongID() != k) {
  145. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  146. scores.get(0).getLongID(), k));
  147. }
  148. }
  149. System.out.println("Search result is correct");
  150. // Retrieve some data
  151. int n = 99;
  152. R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
  153. .withCollectionName(COLLECTION_NAME)
  154. .withExpr(String.format("id == %d", n))
  155. .addOutField(VECTOR_FIELD)
  156. .build());
  157. CommonUtils.handleResponseStatus(queryR);
  158. QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
  159. FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
  160. List<?> r = field.getFieldData();
  161. if (r.isEmpty()) {
  162. throw new RuntimeException("The query result is empty");
  163. } else {
  164. ByteBuffer bf = (ByteBuffer) r.get(0);
  165. if (!bf.equals(vectors.get(n))) {
  166. throw new RuntimeException("The query result is incorrect");
  167. }
  168. }
  169. System.out.println("Query result is correct");
  170. // drop the collection if you don't need the collection anymore
  171. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  172. .withCollectionName(COLLECTION_NAME)
  173. .build());
  174. System.out.println("Collection dropped");
  175. milvusClient.close();
  176. }
  177. public static void main(String[] args) {
  178. testFloat16(true);
  179. testFloat16(false);
  180. }
  181. }