Float16VectorExample.java 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
  19. package io.milvus.v1;
  20. import com.google.gson.Gson;
  21. import com.google.gson.JsonObject;
  22. import io.milvus.client.MilvusServiceClient;
  23. import io.milvus.common.clientenum.ConsistencyLevelEnum;
  24. import io.milvus.grpc.*;
  25. import io.milvus.param.*;
  26. import io.milvus.param.collection.*;
  27. import io.milvus.param.dml.*;
  28. import io.milvus.param.index.*;
  29. import io.milvus.response.*;
  30. import org.tensorflow.types.TBfloat16;
  31. import org.tensorflow.types.TFloat16;
  32. import java.nio.ByteBuffer;
  33. import java.util.*;
  34. public class Float16VectorExample {
  35. private static final String COLLECTION_NAME = "java_sdk_example_float16_vector_v1";
  36. private static final String ID_FIELD = "id";
  37. private static final String VECTOR_FIELD = "vector";
  38. private static final Integer VECTOR_DIM = 128;
  39. private static final MilvusServiceClient milvusClient;
  40. static {
  41. // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
  42. milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
  43. .withHost("localhost")
  44. .withPort(19530)
  45. .build());
  46. }
  47. // For float16 values between 0.0~1.0, the precision can be controlled under 0.001f
  48. // For bfloat16 values between 0.0~1.0, the precision can be controlled under 0.01f
  49. private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
  50. if (bfloat16) {
  51. return Math.abs(a - b) <= 0.01f;
  52. } else {
  53. return Math.abs(a - b) <= 0.001f;
  54. }
  55. }
  56. private static void createCollection(boolean bfloat16) {
  57. // drop the collection if you don't need the collection anymore
  58. R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
  59. .withCollectionName(COLLECTION_NAME)
  60. .build());
  61. CommonUtils.handleResponseStatus(hasR);
  62. if (hasR.getData()) {
  63. dropCollection();
  64. }
  65. // Define fields
  66. DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
  67. List<FieldType> fieldsSchema = Arrays.asList(
  68. FieldType.newBuilder()
  69. .withName(ID_FIELD)
  70. .withDataType(DataType.Int64)
  71. .withPrimaryKey(true)
  72. .withAutoID(false)
  73. .build(),
  74. FieldType.newBuilder()
  75. .withName(VECTOR_FIELD)
  76. .withDataType(dataType)
  77. .withDimension(VECTOR_DIM)
  78. .build()
  79. );
  80. // Create the collection
  81. // Note that we set default consistency level to "STRONG",
  82. // to ensure data is visible to search, for validation the result
  83. R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
  84. .withCollectionName(COLLECTION_NAME)
  85. .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
  86. .withFieldTypes(fieldsSchema)
  87. .build());
  88. CommonUtils.handleResponseStatus(ret);
  89. System.out.println("Collection created");
  90. // Specify an index type on the vector field.
  91. ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
  92. .withCollectionName(COLLECTION_NAME)
  93. .withFieldName(VECTOR_FIELD)
  94. .withIndexType(IndexType.IVF_FLAT)
  95. .withMetricType(MetricType.L2)
  96. .withExtraParam("{\"nlist\":128}")
  97. .build());
  98. CommonUtils.handleResponseStatus(ret);
  99. System.out.println("Index created");
  100. // Call loadCollection() to enable automatically loading data into memory for searching
  101. ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
  102. .withCollectionName(COLLECTION_NAME)
  103. .build());
  104. CommonUtils.handleResponseStatus(ret);
  105. System.out.println("Collection loaded");
  106. }
  107. private static void dropCollection() {
  108. milvusClient.dropCollection(DropCollectionParam.newBuilder()
  109. .withCollectionName(COLLECTION_NAME)
  110. .build());
  111. System.out.println("Collection dropped");
  112. }
  113. private static void testFloat16(boolean bfloat16) {
  114. DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
  115. System.out.printf("============ testFloat16 %s ===================\n", dataType.name());
  116. createCollection(bfloat16);
  117. // Insert 5000 entities by columns
  118. // Prepare original vectors, then encode into ByteBuffer
  119. int batchRowCount = 5000;
  120. List<List<Float>> originVectors = CommonUtils.generateFloatVectors(VECTOR_DIM, batchRowCount);
  121. List<ByteBuffer> encodedVectors = CommonUtils.encodeFloat16Vectors(originVectors, bfloat16);
  122. List<Long> ids = new ArrayList<>();
  123. for (long i = 0L; i < batchRowCount; ++i) {
  124. ids.add(i);
  125. }
  126. List<InsertParam.Field> fieldsInsert = new ArrayList<>();
  127. fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
  128. fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
  129. R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
  130. .withCollectionName(COLLECTION_NAME)
  131. .withFields(fieldsInsert)
  132. .build());
  133. CommonUtils.handleResponseStatus(insertR);
  134. System.out.println(ids.size() + " rows inserted");
  135. // Insert 5000 entities by rows
  136. List<JsonObject> rows = new ArrayList<>();
  137. Gson gson = new Gson();
  138. for (int i = 0; i < batchRowCount; ++i) {
  139. JsonObject row = new JsonObject();
  140. row.addProperty(ID_FIELD, batchRowCount + i);
  141. List<Float> originVector = CommonUtils.generateFloatVector(VECTOR_DIM);
  142. originVectors.add(originVector);
  143. ByteBuffer buf = CommonUtils.encodeFloat16Vector(originVector, bfloat16);
  144. encodedVectors.add(buf);
  145. row.add(VECTOR_FIELD, gson.toJsonTree(buf.array()));
  146. rows.add(row);
  147. }
  148. insertR = milvusClient.insert(InsertParam.newBuilder()
  149. .withCollectionName(COLLECTION_NAME)
  150. .withRows(rows)
  151. .build());
  152. CommonUtils.handleResponseStatus(insertR);
  153. System.out.println(ids.size() + " rows inserted");
  154. // Pick some random vectors from the original vectors to search
  155. // Ensure the returned top1 item's ID should be equal to target vector's ID
  156. for (int i = 0; i < 10; i++) {
  157. Random ran = new Random();
  158. int k = ran.nextInt(batchRowCount*2);
  159. ByteBuffer targetVector = encodedVectors.get(k);
  160. SearchParam.Builder builder = SearchParam.newBuilder()
  161. .withCollectionName(COLLECTION_NAME)
  162. .withMetricType(MetricType.L2)
  163. .withTopK(3)
  164. .withVectorFieldName(VECTOR_FIELD)
  165. .addOutField(VECTOR_FIELD)
  166. .withParams("{\"nprobe\":32}");
  167. if (bfloat16) {
  168. builder.withBFloat16Vectors(Collections.singletonList(targetVector));
  169. } else {
  170. builder.withFloat16Vectors(Collections.singletonList(targetVector));
  171. }
  172. R<SearchResults> searchRet = milvusClient.search(builder.build());
  173. CommonUtils.handleResponseStatus(searchRet);
  174. // The search() allows multiple target vectors to search in a batch.
  175. // Here we only input one vector to search, get the result of No.0 vector to check
  176. SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
  177. List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
  178. System.out.printf("The result of No.%d target vector:\n", i);
  179. for (SearchResultsWrapper.IDScore score : scores) {
  180. System.out.println(score);
  181. }
  182. SearchResultsWrapper.IDScore firstScore = scores.get(0);
  183. if (firstScore.getLongID() != k) {
  184. throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
  185. firstScore.getLongID(), k));
  186. }
  187. ByteBuffer outputBuf = (ByteBuffer)firstScore.get(VECTOR_FIELD);
  188. if (!outputBuf.equals(targetVector)) {
  189. throw new RuntimeException(String.format("The output vector is not equal to target vector: ID %d", k));
  190. }
  191. List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
  192. List<Float> originVector = originVectors.get(k);
  193. for (int j = 0; j < outputVector.size(); j++) {
  194. if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
  195. throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
  196. }
  197. }
  198. }
  199. System.out.println("Search result is correct");
  200. // Retrieve some data and verify the output
  201. for (int i = 0; i < 10; i++) {
  202. Random ran = new Random();
  203. int k = ran.nextInt(batchRowCount*2);
  204. R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
  205. .withCollectionName(COLLECTION_NAME)
  206. .withExpr(String.format("id == %d", k))
  207. .addOutField(VECTOR_FIELD)
  208. .build());
  209. CommonUtils.handleResponseStatus(queryR);
  210. QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
  211. FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
  212. List<?> r = field.getFieldData();
  213. if (r.isEmpty()) {
  214. throw new RuntimeException("The query result is empty");
  215. } else {
  216. ByteBuffer outputBuf = (ByteBuffer) r.get(0);
  217. ByteBuffer targetVector = encodedVectors.get(k);
  218. if (!outputBuf.equals(targetVector)) {
  219. throw new RuntimeException("The query result is incorrect");
  220. }
  221. List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
  222. List<Float> originVector = originVectors.get(k);
  223. for (int j = 0; j < outputVector.size(); j++) {
  224. if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
  225. throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
  226. }
  227. }
  228. }
  229. }
  230. System.out.println("Query result is correct");
  231. // drop the collection if you don't need the collection anymore
  232. dropCollection();
  233. }
  234. private static void testTensorflowFloat16(boolean bfloat16) {
  235. DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
  236. System.out.printf("============ testTensorflowFloat16 %s ===================\n", dataType.name());
  237. createCollection(bfloat16);
  238. // Prepare tensorflow vectors, convert to ByteBuffer and insert
  239. int rowCount = 10000;
  240. List<Long> ids = new ArrayList<>();
  241. for (long i = 0L; i < rowCount; ++i) {
  242. ids.add(i);
  243. }
  244. List<InsertParam.Field> fieldsInsert = new ArrayList<>();
  245. fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
  246. List<ByteBuffer> encodedVectors;
  247. if (bfloat16) {
  248. List<TBfloat16> tfVectors = CommonUtils.genTensorflowBF16Vectors(VECTOR_DIM, rowCount);
  249. encodedVectors = CommonUtils.encodeTensorBF16Vectors(tfVectors);
  250. } else {
  251. List<TFloat16> tfVectors = CommonUtils.genTensorflowFP16Vectors(VECTOR_DIM, rowCount);
  252. encodedVectors = CommonUtils.encodeTensorFP16Vectors(tfVectors);
  253. }
  254. fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
  255. R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
  256. .withCollectionName(COLLECTION_NAME)
  257. .withFields(fieldsInsert)
  258. .build());
  259. CommonUtils.handleResponseStatus(insertR);
  260. System.out.println(ids.size() + " rows inserted");
  261. // Retrieve some data and verify the output
  262. Random ran = new Random();
  263. int k = ran.nextInt(rowCount);
  264. R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
  265. .withCollectionName(COLLECTION_NAME)
  266. .withExpr(String.format("id == %d", k))
  267. .addOutField(VECTOR_FIELD)
  268. .build());
  269. CommonUtils.handleResponseStatus(queryR);
  270. QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
  271. FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
  272. List<?> r = field.getFieldData();
  273. if (r.isEmpty()) {
  274. throw new RuntimeException("The query result is empty");
  275. }
  276. ByteBuffer outputBuf = (ByteBuffer) r.get(0);
  277. ByteBuffer originVector = encodedVectors.get(k);
  278. if (!outputBuf.equals(originVector)) {
  279. throw new RuntimeException("The query result is incorrect");
  280. }
  281. List<Float> vector = new ArrayList<>();
  282. if (bfloat16) {
  283. TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf);
  284. for (long i = 0; i < tf.size(); i++) {
  285. vector.add(tf.getFloat(i));
  286. }
  287. } else {
  288. TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf);
  289. for (long i = 0; i < tf.size(); i++) {
  290. vector.add(tf.getFloat(i));
  291. }
  292. }
  293. System.out.println(vector);
  294. System.out.println("Query result is correct");
  295. // drop the collection if you don't need the collection anymore
  296. dropCollection();
  297. }
  298. public static void main(String[] args) {
  299. testFloat16(true);
  300. testFloat16(false);
  301. testTensorflowFloat16(true);
  302. testTensorflowFloat16(false);
  303. }
  304. }