Ver código fonte

Improve usability of float16 (#940)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 10 meses atrás
pai
commit
c8c5924021

+ 2 - 2
docker-compose.yml

@@ -32,7 +32,7 @@ services:
 
   standalone:
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:v2.4.0-20240416-ffb6edd4-amd64
+    image: milvusdb/milvus:v2.4.4
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-test-slave-standalone
-    image: milvusdb/milvus:v2.4.0-20240416-ffb6edd4-amd64
+    image: milvusdb/milvus:v2.4.4
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcdslave:2379

+ 135 - 16
examples/main/java/io/milvus/v1/CommonUtils.java

@@ -18,8 +18,13 @@
  */
 package io.milvus.v1;
 
+import io.milvus.common.utils.Float16Utils;
 import io.milvus.param.R;
+
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.Shape;
 import org.tensorflow.ndarray.buffer.ByteDataBuffer;
+import org.tensorflow.ndarray.buffer.DataBuffers;
 import org.tensorflow.types.TBfloat16;
 import org.tensorflow.types.TFloat16;
 
@@ -70,9 +75,11 @@ public class CommonUtils {
         return vectors;
     }
 
+    /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static ByteBuffer generateBinaryVector(int dimension) {
         Random ran = new Random();
         int byteCount = dimension / 8;
+        // binary vector doesn't care endian since each byte is independent
         ByteBuffer vector = ByteBuffer.allocate(byteCount);
         for (int i = 0; i < byteCount; ++i) {
             vector.put((byte) ran.nextInt(Byte.MAX_VALUE));
@@ -89,38 +96,150 @@ public class CommonUtils {
         return vectors;
     }
 
-    public static ByteBuffer generateFloat16Vector(int dimension, boolean bfloat16) {
+    /////////////////////////////////////////////////////////////////////////////////////////////////////
+    public static TBfloat16 genTensorflowBF16Vector(int dimension) {
         Random ran = new Random();
-        int byteCount = dimension*2;
-        ByteBuffer vector = ByteBuffer.allocate(byteCount);
-        for (int i = 0; i < dimension; ++i) {
-            ByteDataBuffer bf;
+        float[] array = new float[dimension];
+        for (int n = 0; n < dimension; ++n) {
+            array[n] = ran.nextFloat();
+        }
+
+        return TBfloat16.vectorOf(array);
+    }
+
+    public static List<TBfloat16> genTensorflowBF16Vectors(int dimension, int count) {
+        List<TBfloat16> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+           TBfloat16 vector = genTensorflowBF16Vector(dimension);
+            vectors.add(vector);
+        }
+
+        return vectors;
+    }
+
+    public static ByteBuffer encodeTensorBF16Vector(TBfloat16 vector) {
+        ByteDataBuffer tensorBuf = vector.asRawTensor().data();
+        ByteBuffer buf = ByteBuffer.allocate((int)tensorBuf.size());
+        for (long i = 0; i < tensorBuf.size(); i++) {
+            buf.put(tensorBuf.getByte(i));
+        }
+        return buf;
+    }
+
+    public static List<ByteBuffer> encodeTensorBF16Vectors(List<TBfloat16> vectors) {
+        List<ByteBuffer> buffers = new ArrayList<>();
+        for (TBfloat16 tf : vectors) {
+            ByteBuffer bf = encodeTensorBF16Vector(tf);
+            buffers.add(bf);
+        }
+        return buffers;
+    }
+
+    public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) {
+        if (buf.limit()%2 != 0) {
+            return null;
+        }
+        int dim = buf.limit()/2;
+        ByteDataBuffer bf = DataBuffers.of(buf.array());
+        return Tensor.of(TBfloat16.class, Shape.of(dim), bf);
+    }
+
+
+    public static TFloat16 genTensorflowFP16Vector(int dimension) {
+        Random ran = new Random();
+        float[] array = new float[dimension];
+        for (int n = 0; n < dimension; ++n) {
+            array[n] = ran.nextFloat();
+        }
+
+        return TFloat16.vectorOf(array);
+    }
+
+    public static List<TFloat16> genTensorflowFP16Vectors(int dimension, int count) {
+        List<TFloat16> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+            TFloat16 vector = genTensorflowFP16Vector(dimension);
+            vectors.add(vector);
+        }
+
+        return vectors;
+    }
+
+    public static ByteBuffer encodeTensorFP16Vector(TFloat16 vector) {
+        ByteDataBuffer tensorBuf = vector.asRawTensor().data();
+        ByteBuffer buf = ByteBuffer.allocate((int)tensorBuf.size());
+        for (long i = 0; i < tensorBuf.size(); i++) {
+            buf.put(tensorBuf.getByte(i));
+        }
+        return buf;
+    }
+
+    public static List<ByteBuffer> encodeTensorFP16Vectors(List<TFloat16> vectors) {
+        List<ByteBuffer> buffers = new ArrayList<>();
+        for (TFloat16 tf : vectors) {
+            ByteBuffer bf = encodeTensorFP16Vector(tf);
+            buffers.add(bf);
+        }
+        return buffers;
+    }
+
+    public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) {
+        if (buf.limit()%2 != 0) {
+            return null;
+        }
+        int dim = buf.limit()/2;
+        ByteDataBuffer bf = DataBuffers.of(buf.array());
+        return Tensor.of(TFloat16.class, Shape.of(dim), bf);
+    }
+
+    /////////////////////////////////////////////////////////////////////////////////////////////////////
+    public static ByteBuffer encodeFloat16Vector(List<Float> originVector, boolean bfloat16) {
+        if (bfloat16) {
+            return Float16Utils.f32VectorToBf16Buffer(originVector);
+        } else {
+            return Float16Utils.f32VectorToFp16Buffer(originVector);
+        }
+    }
+
+    public static List<Float> decodeFloat16Vector(ByteBuffer buf, boolean bfloat16) {
+        if (bfloat16) {
+            return Float16Utils.bf16BufferToVector(buf);
+        } else {
+            return Float16Utils.fp16BufferToVector(buf);
+        }
+    }
+
+    public static List<ByteBuffer> encodeFloat16Vectors(List<List<Float>> originVectors, boolean bfloat16) {
+        List<ByteBuffer> vectors = new ArrayList<>();
+        for (List<Float> originVector : originVectors) {
             if (bfloat16) {
-                TBfloat16 tt = TBfloat16.scalarOf((float)ran.nextInt(dimension));
-                bf = tt.asRawTensor().data();
+                vectors.add(Float16Utils.f32VectorToBf16Buffer(originVector));
             } else {
-                TFloat16 tt = TFloat16.scalarOf((float)ran.nextInt(dimension));
-                bf = tt.asRawTensor().data();
+                vectors.add(Float16Utils.f32VectorToFp16Buffer(originVector));
             }
-            vector.put(bf.getByte(0));
-            vector.put(bf.getByte(1));
         }
-        return vector;
+        return vectors;
+    }
+
+    public static ByteBuffer generateFloat16Vector(int dimension, boolean bfloat16) {
+        List<Float> originalVector = generateFloatVector(dimension);
+        return encodeFloat16Vector(originalVector, bfloat16);
     }
 
     public static List<ByteBuffer> generateFloat16Vectors(int dimension, int count, boolean bfloat16) {
         List<ByteBuffer> vectors = new ArrayList<>();
-        for (int n = 0; n < count; ++n) {
-            ByteBuffer vector = generateFloat16Vector(dimension, bfloat16);
-            vectors.add(vector);
+        for (int i = 0; i < count; i++) {
+            ByteBuffer buf = generateFloat16Vector(dimension, bfloat16);
+            vectors.add((buf));
         }
         return vectors;
     }
 
+    /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static SortedMap<Long, Float> generateSparseVector() {
         Random ran = new Random();
         SortedMap<Long, Float> sparse = new TreeMap<>();
-        int dim = ran.nextInt(10) + 1;
+        int dim = ran.nextInt(10) + 10;
         for (int i = 0; i < dim; ++i) {
             sparse.put((long)ran.nextInt(1000000), ran.nextFloat());
         }

+ 184 - 147
examples/main/java/io/milvus/v1/Float16VectorExample.java

@@ -28,47 +28,51 @@ import io.milvus.param.collection.*;
 import io.milvus.param.dml.*;
 import io.milvus.param.index.*;
 import io.milvus.response.*;
+import org.tensorflow.types.TBfloat16;
+import org.tensorflow.types.TFloat16;
 
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import org.tensorflow.Tensor;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.ndarray.buffer.ByteDataBuffer;
-import org.tensorflow.ndarray.buffer.DataBuffers;
-import org.tensorflow.types.TBfloat16;
-import org.tensorflow.types.TFloat16;
-
 
 public class Float16VectorExample {
     private static final String COLLECTION_NAME = "java_sdk_example_float16";
     private static final String ID_FIELD = "id";
     private static final String VECTOR_FIELD = "vector";
     private static final Integer VECTOR_DIM = 128;
-    
-
-    private static void testFloat16(boolean bfloat16) {
-        DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
-        System.out.printf("=================== %s ===================\n", dataType.name());
 
+    private static final MilvusServiceClient milvusClient;
+    static {
         // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
-        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
+        milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
                 .withHost("localhost")
                 .withPort(19530)
                 .build());
+    }
+
+    // For float16 values between 0.0~1.0, the precision can be controlled under 0.001f
+    // For bfloat16 values between 0.0~1.0, the precision can be controlled under 0.01f
+    private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
+        if (bfloat16) {
+            return Math.abs(a - b) <= 0.01f;
+        } else {
+            return Math.abs(a - b) <= 0.001f;
+        }
+    }
+
+    private static void createCollection(boolean bfloat16) {
 
         // drop the collection if you don't need the collection anymore
         R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
-                                    .withCollectionName(COLLECTION_NAME)
-                                    .build());
+                .withCollectionName(COLLECTION_NAME)
+                .build());
         CommonUtils.handleResponseStatus(hasR);
         if (hasR.getData()) {
-            milvusClient.dropCollection(DropCollectionParam.newBuilder()
-                    .withCollectionName(COLLECTION_NAME)
-                    .build());
+            dropCollection();
         }
 
         // Define fields
+        DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
         List<FieldType> fieldsSchema = Arrays.asList(
                 FieldType.newBuilder()
                         .withName(ID_FIELD)
@@ -84,6 +88,8 @@ public class Float16VectorExample {
         );
 
         // Create the collection
+        // Note that we set default consistency level to "STRONG",
+        // to ensure data is visible to search, for validation the result
         R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
@@ -92,31 +98,73 @@ public class Float16VectorExample {
         CommonUtils.handleResponseStatus(ret);
         System.out.println("Collection created");
 
-        // Insert entities by columns
-        int rowCount = 10000;
+        // Specify an index type on the vector field.
+        ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.IVF_FLAT)
+                .withMetricType(MetricType.L2)
+                .withExtraParam("{\"nlist\":128}")
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+        System.out.println("Index created");
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        CommonUtils.handleResponseStatus(ret);
+        System.out.println("Collection loaded");
+    }
+
+    private static void dropCollection() {
+        milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection dropped");
+    }
+
+    private static void testFloat16(boolean bfloat16) {
+        DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
+        System.out.printf("============ testFloat16 %s ===================\n", dataType.name());
+
+        createCollection(bfloat16);
+
+        // Insert 5000 entities by columns
+        // Prepare original vectors, then encode into ByteBuffer
+        int batchRowCount = 5000;
+        List<List<Float>> originVectors = CommonUtils.generateFloatVectors(VECTOR_DIM, batchRowCount);
+        List<ByteBuffer> encodedVectors = CommonUtils.encodeFloat16Vectors(originVectors, bfloat16);
+
         List<Long> ids = new ArrayList<>();
-        for (long i = 0L; i < rowCount; ++i) {
+        for (long i = 0L; i < batchRowCount; ++i) {
             ids.add(i);
         }
-        List<ByteBuffer> vectors = CommonUtils.generateFloat16Vectors(VECTOR_DIM, rowCount, bfloat16);
-
         List<InsertParam.Field> fieldsInsert = new ArrayList<>();
         fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
-        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
+        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
 
         R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withFields(fieldsInsert)
                 .build());
         CommonUtils.handleResponseStatus(insertR);
+        System.out.println(ids.size() + " rows inserted");
 
-        // Insert entities by rows
+        // Insert 5000 entities by rows
         List<JsonObject> rows = new ArrayList<>();
         Gson gson = new Gson();
-        for (long i = 1L; i <= rowCount; ++i) {
+        for (int i = 0; i < batchRowCount; ++i) {
             JsonObject row = new JsonObject();
-            row.addProperty(ID_FIELD, rowCount + i);
-            row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloat16Vector(VECTOR_DIM, bfloat16).array()));
+            row.addProperty(ID_FIELD, batchRowCount + i);
+
+            List<Float> originVector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            originVectors.add(originVector);
+
+            ByteBuffer buf = CommonUtils.encodeFloat16Vector(originVector, bfloat16);
+            encodedVectors.add(buf);
+
+            row.add(VECTOR_FIELD, gson.toJsonTree(buf.array()));
             rows.add(row);
         }
 
@@ -125,39 +173,14 @@ public class Float16VectorExample {
                 .withRows(rows)
                 .build());
         CommonUtils.handleResponseStatus(insertR);
+        System.out.println(ids.size() + " rows inserted");
 
-        // Flush the data to storage for testing purpose
-        // Note that no need to manually call flush interface in practice
-        R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
-                addCollectionName(COLLECTION_NAME).
-                build());
-        CommonUtils.handleResponseStatus(flushR);
-        System.out.println("Entities inserted");
-
-        // Specify an index type on the vector field.
-        ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
-                .withCollectionName(COLLECTION_NAME)
-                .withFieldName(VECTOR_FIELD)
-                .withIndexType(IndexType.IVF_FLAT)
-                .withMetricType(MetricType.L2)
-                .withExtraParam("{\"nlist\":128}")
-                .build());
-        CommonUtils.handleResponseStatus(ret);
-        System.out.println("Index created");
-
-        // Call loadCollection() to enable automatically loading data into memory for searching
-        ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
-                .withCollectionName(COLLECTION_NAME)
-                .build());
-        CommonUtils.handleResponseStatus(ret);
-        System.out.println("Collection loaded");
-
-        // Pick some vectors from the inserted vectors to search
+        // Pick some random vectors from the original vectors to search
         // Ensure the returned top1 item's ID should be equal to target vector's ID
         for (int i = 0; i < 10; i++) {
             Random ran = new Random();
-            int k = ran.nextInt(rowCount);
-            ByteBuffer targetVector = vectors.get(k);
+            int k = ran.nextInt(batchRowCount*2);
+            ByteBuffer targetVector = encodedVectors.get(k);
             SearchParam.Builder builder = SearchParam.newBuilder()
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.L2)
@@ -181,128 +204,142 @@ public class Float16VectorExample {
             for (SearchResultsWrapper.IDScore score : scores) {
                 System.out.println(score);
             }
-            if (scores.get(0).getLongID() != k) {
+
+            SearchResultsWrapper.IDScore firstScore = scores.get(0);
+            if (firstScore.getLongID() != k) {
                 throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
-                        scores.get(0).getLongID(), k));
+                        firstScore.getLongID(), k));
+            }
+
+            ByteBuffer outputBuf = (ByteBuffer)firstScore.get(VECTOR_FIELD);
+            if (!outputBuf.equals(targetVector)) {
+                throw new RuntimeException(String.format("The output vector is not equal to target vector: ID %d", k));
+            }
+
+            List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
+            List<Float> originVector = originVectors.get(k);
+            for (int j = 0; j < outputVector.size(); j++) {
+                if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
+                    throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
+                }
             }
         }
         System.out.println("Search result is correct");
 
-        // Retrieve some data
-        int n = 99;
-        R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
-                .withCollectionName(COLLECTION_NAME)
-                .withExpr(String.format("id == %d", n))
-                .addOutField(VECTOR_FIELD)
-                .build());
-        CommonUtils.handleResponseStatus(queryR);
-        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
-        FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
-        List<?> r = field.getFieldData();
-        if (r.isEmpty()) {
-            throw new RuntimeException("The query result is empty");
-        } else {
-            ByteBuffer bf = (ByteBuffer) r.get(0);
-            if (!bf.equals(vectors.get(n))) {
-                throw new RuntimeException("The query result is incorrect");
+        // Retrieve some data and verify the output
+        for (int i = 0; i < 10; i++) {
+            Random ran = new Random();
+            int k = ran.nextInt(batchRowCount*2);
+            R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .withExpr(String.format("id == %d", k))
+                    .addOutField(VECTOR_FIELD)
+                    .build());
+            CommonUtils.handleResponseStatus(queryR);
+            QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
+            FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+            List<?> r = field.getFieldData();
+            if (r.isEmpty()) {
+                throw new RuntimeException("The query result is empty");
+            } else {
+                ByteBuffer outputBuf = (ByteBuffer) r.get(0);
+                ByteBuffer targetVector = encodedVectors.get(k);
+                if (!outputBuf.equals(targetVector)) {
+                    throw new RuntimeException("The query result is incorrect");
+                }
+
+                List<Float> outputVector = CommonUtils.decodeFloat16Vector(outputBuf, bfloat16);
+                List<Float> originVector = originVectors.get(k);
+                for (int j = 0; j < outputVector.size(); j++) {
+                    if (!isFloat16Eauql(outputVector.get(j), originVector.get(j), bfloat16)) {
+                        throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
+                    }
+                }
             }
         }
         System.out.println("Query result is correct");
 
-        // insert a single row
-        JsonObject row = new JsonObject();
-        row.addProperty(ID_FIELD, 9999999);
-        List<Float> newVector = CommonUtils.generateFloatVector(VECTOR_DIM);
-        ByteBuffer vector16Buf = encodeTF(newVector, bfloat16);
-        row.add(VECTOR_FIELD, gson.toJsonTree(vector16Buf.array()));
-        insertR = milvusClient.insert(InsertParam.newBuilder()
+        // drop the collection if you don't need the collection anymore
+        dropCollection();
+    }
+
+    private static void testTensorflowFloat16(boolean bfloat16) {
+        DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
+        System.out.printf("============ testTensorflowFloat16 %s ===================\n", dataType.name());
+        createCollection(bfloat16);
+
+        // Prepare tensorflow vectors, convert to ByteBuffer and insert
+        int rowCount = 10000;
+        List<Long> ids = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            ids.add(i);
+        }
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
+
+        List<ByteBuffer> encodedVectors;
+        if (bfloat16) {
+            List<TBfloat16> tfVectors = CommonUtils.genTensorflowBF16Vectors(VECTOR_DIM, rowCount);
+            encodedVectors = CommonUtils.encodeTensorBF16Vectors(tfVectors);
+        } else {
+            List<TFloat16> tfVectors = CommonUtils.genTensorflowFP16Vectors(VECTOR_DIM, rowCount);
+            encodedVectors = CommonUtils.encodeTensorFP16Vectors(tfVectors);
+        }
+        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, encodedVectors));
+
+        R<MutationResult> insertR = milvusClient.insert(InsertParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
-                .withRows(Collections.singletonList(row))
+                .withFields(fieldsInsert)
                 .build());
         CommonUtils.handleResponseStatus(insertR);
+        System.out.println(ids.size() + " rows inserted");
 
-        // retrieve the single row
-        queryR = milvusClient.query(QueryParam.newBuilder()
+        // Retrieve some data and verify the output
+        Random ran = new Random();
+        int k = ran.nextInt(rowCount);
+        R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
-                .withExpr("id == 9999999")
+                .withExpr(String.format("id == %d", k))
                 .addOutField(VECTOR_FIELD)
                 .build());
         CommonUtils.handleResponseStatus(queryR);
-        queryWrapper = new QueryResultsWrapper(queryR.getData());
-        field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
-        r = field.getFieldData();
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
+        FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+        List<?> r = field.getFieldData();
         if (r.isEmpty()) {
-            throw new RuntimeException("The retrieve result is empty");
-        } else {
-            ByteBuffer outBuf = (ByteBuffer) r.get(0);
-            List<Float> outVector = decodeTF(outBuf, bfloat16);
-            if (outVector.size() != newVector.size()) {
-                throw new RuntimeException("The retrieve result is incorrect");
-            }
-            for (int i = 0; i < outVector.size(); i++) {
-                if (!isFloat16Eauql(outVector.get(i), newVector.get(i), bfloat16)) {
-                    throw new RuntimeException("The retrieve result is incorrect");
-                }
-            }
+            throw new RuntimeException("The query result is empty");
         }
-        System.out.println("Retrieve result is correct");
 
-        // drop the collection if you don't need the collection anymore
-        milvusClient.dropCollection(DropCollectionParam.newBuilder()
-                .withCollectionName(COLLECTION_NAME)
-                .build());
-        System.out.println("Collection dropped");
-
-        milvusClient.close();
-    }
-
-    private static ByteBuffer encodeTF(List<Float> vector, boolean bfloat16) {
-        ByteBuffer buf = ByteBuffer.allocate(vector.size() * 2);
-        for (Float value : vector) {
-            ByteDataBuffer bf;
-            if (bfloat16) {
-                TBfloat16 tt = TBfloat16.scalarOf(value);
-                bf = tt.asRawTensor().data();
-            } else {
-                TFloat16 tt = TFloat16.scalarOf(value);
-                bf = tt.asRawTensor().data();
-            }
-            buf.put(bf.getByte(0));
-            buf.put(bf.getByte(1));
+        ByteBuffer outputBuf = (ByteBuffer) r.get(0);
+        ByteBuffer originVector = encodedVectors.get(k);
+        if (!outputBuf.equals(originVector)) {
+            throw new RuntimeException("The query result is incorrect");
         }
-        return buf;
-    }
 
-    private static List<Float> decodeTF(ByteBuffer buf, boolean bfloat16) {
-        int dim = buf.limit()/2;
-        ByteDataBuffer bf = DataBuffers.of(buf.array());
-        List<Float> vec = new ArrayList<>();
+        List<Float> vector = new ArrayList<>();
         if (bfloat16) {
-            TBfloat16 tf = Tensor.of(TBfloat16.class, Shape.of(dim), bf);
-            for (long k = 0; k < tf.size(); k++) {
-                vec.add(tf.getFloat(k));
+            TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf);
+            for (long i = 0; i < tf.size(); i++) {
+                vector.add(tf.getFloat(i));
             }
         } else {
-            TFloat16 tf = Tensor.of(TFloat16.class, Shape.of(dim), bf);
-            for (long k = 0; k < tf.size(); k++) {
-                vec.add(tf.getFloat(k));
+            TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf);
+            for (long i = 0; i < tf.size(); i++) {
+                vector.add(tf.getFloat(i));
             }
         }
+        System.out.println(vector);
+        System.out.println("Query result is correct");
 
-        return vec;
-    }
-
-    private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
-        if (bfloat16) {
-            return Math.abs(a - b) <= 0.01f;
-        } else {
-            return Math.abs(a - b) <= 0.001f;
-        }
+        // drop the collection if you don't need the collection anymore
+        dropCollection();
     }
 
-
     public static void main(String[] args) {
         testFloat16(true);
         testFloat16(false);
+
+        testTensorflowFloat16(true);
+        testTensorflowFloat16(false);
     }
 }

+ 1 - 0
examples/main/java/io/milvus/v1/GeneralExample.java

@@ -364,6 +364,7 @@ public class GeneralExample {
                 .withCollectionName(COLLECTION_NAME)
                 .withExpr(expr)
                 .withOutFields(fields)
+                .withLimit(10L)
                 .build();
         R<QueryResults> response = milvusClient.query(test);
         CommonUtils.handleResponseStatus(response);

+ 8 - 2
examples/main/java/io/milvus/v1/SparseVectorExample.java

@@ -168,9 +168,15 @@ public class SparseVectorExample {
             for (SearchResultsWrapper.IDScore score : scores) {
                 System.out.println(score);
             }
-            if (scores.get(0).getLongID() != k) {
+
+            SearchResultsWrapper.IDScore firstScore = scores.get(0);
+            if (firstScore.getLongID() != k) {
                 throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
-                        scores.get(0).getLongID(), k));
+                        firstScore.getLongID(), k));
+            }
+            SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) firstScore.get(VECTOR_FIELD);
+            if (!sparse.equals(targetVector)) {
+                throw new RuntimeException("The query result is incorrect");
             }
         }
         System.out.println("Search result is correct");

+ 198 - 0
examples/main/java/io/milvus/v2/Float16VectorExample.java

@@ -0,0 +1,198 @@
+package io.milvus.v2;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.HasCollectionReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchReq;
+import io.milvus.v2.service.vector.request.data.BFloat16Vec;
+import io.milvus.v2.service.vector.request.data.BaseVector;
+import io.milvus.v2.service.vector.request.data.Float16Vec;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
+
+import java.nio.ByteBuffer;
+import java.util.*;
+
+
+public class Float16VectorExample {
+    private static final String COLLECTION_NAME = "java_sdk_example_float16";
+    private static final String ID_FIELD = "id";
+    private static final String FP16_VECTOR_FIELD = "fp16_vector";
+    private static final String BF16_VECTOR_FIELD = "bf16_vector";
+    private static final Integer VECTOR_DIM = 128;
+
+    private static final MilvusClientV2 milvusClient;
+    static {
+        milvusClient = new MilvusClientV2(ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build());
+    }
+
+    private static void createCollection() {
+
+        // drop the collection if you don't need the collection anymore
+        Boolean has = milvusClient.hasCollection(HasCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+        if (has) {
+            dropCollection();
+        }
+
+        // build a collection with two vector fields
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(FP16_VECTOR_FIELD)
+                .dataType(io.milvus.v2.common.DataType.Float16Vector)
+                .dimension(VECTOR_DIM)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(BF16_VECTOR_FIELD)
+                .dataType(io.milvus.v2.common.DataType.BFloat16Vector)
+                .dimension(VECTOR_DIM)
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        Map<String,Object> extraParams = new HashMap<>();
+        extraParams.put("nlist",64);
+        indexes.add(IndexParam.builder()
+                .fieldName(FP16_VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.IVF_FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .extraParams(extraParams)
+                .build());
+        indexes.add(IndexParam.builder()
+                .fieldName(BF16_VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+        milvusClient.createCollection(requestCreate);
+    }
+
+    private static void prepareData(int count) {
+        List<JsonObject> rows = new ArrayList<>();
+        Gson gson = new Gson();
+        for (long i = 0; i < count; i++) {
+            JsonObject row = new JsonObject();
+            row.addProperty(ID_FIELD, i);
+            ByteBuffer buf1 = CommonUtils.generateFloat16Vector(VECTOR_DIM, false);
+            row.add(FP16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
+            ByteBuffer buf2 = CommonUtils.generateFloat16Vector(VECTOR_DIM, true);
+            row.add(BF16_VECTOR_FIELD, gson.toJsonTree(buf1.array()));
+            rows.add(row);
+        }
+
+        InsertResp insertResp = milvusClient.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rows)
+                .build());
+        System.out.println(insertResp.getInsertCnt() + " rows inserted");
+    }
+
+    private static void searchVectors(List<Long> taargetIDs, List<BaseVector> targetVectors, String vectorFieldName) {
+        int topK = 5;
+        SearchResp searchResp = milvusClient.search(SearchReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(targetVectors)
+                .annsField(vectorFieldName)
+                .topK(topK)
+                .outputFields(Collections.singletonList(vectorFieldName))
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        if (searchResults.isEmpty()) {
+            throw new RuntimeException("The search result is empty");
+        }
+
+        for (int i = 0; i < searchResults.size(); i++) {
+            List<SearchResp.SearchResult> results = searchResults.get(i);
+            if (results.size() != topK) {
+                throw new RuntimeException(String.format("The search result should contains top%d items", topK));
+            }
+
+            SearchResp.SearchResult topResult = results.get(0);
+            long id = (long) topResult.getId();
+            if (id != taargetIDs.get(i)) {
+                throw new RuntimeException("The top1 id is incorrect");
+            }
+            Map<String, Object> entity = topResult.getEntity();
+            ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName);
+            if (!vectorBuf.equals(targetVectors.get(i).getData())) {
+                throw new RuntimeException("The top1 output vector is incorrect");
+            }
+        }
+        System.out.println("Search result of float16 vector is correct");
+    }
+
+    private static void search() {
+        // retrieve some rows for search
+        List<Long> targetIDs = Arrays.asList(999L, 2024L);
+        QueryResp queryResp = milvusClient.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .filter(ID_FIELD + " in " + targetIDs)
+                .outputFields(Arrays.asList(FP16_VECTOR_FIELD, BF16_VECTOR_FIELD))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+
+        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+        if (queryResults.isEmpty()) {
+            throw new RuntimeException("The query result is empty");
+        }
+
+        List<BaseVector> targetFP16Vectors = new ArrayList<>();
+        List<BaseVector> targetBF16Vectors = new ArrayList<>();
+        for (QueryResp.QueryResult queryResult : queryResults) {
+            Map<String, Object> entity = queryResult.getEntity();
+            ByteBuffer f16VectorBuf = (ByteBuffer) entity.get(FP16_VECTOR_FIELD);
+            targetFP16Vectors.add(new Float16Vec(f16VectorBuf));
+            ByteBuffer bf16VectorBuf = (ByteBuffer) entity.get(BF16_VECTOR_FIELD);
+            targetBF16Vectors.add(new BFloat16Vec(bf16VectorBuf));
+        }
+
+        // search float16 vector
+        searchVectors(targetIDs, targetFP16Vectors, FP16_VECTOR_FIELD);
+
+        // search bfloat16 vector
+        searchVectors(targetIDs, targetBF16Vectors, BF16_VECTOR_FIELD);
+    }
+
+    private static void dropCollection() {
+        milvusClient.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection dropped");
+    }
+
+
+    public static void main(String[] args) {
+        createCollection();
+        prepareData(10000);
+        search();
+        dropCollection();
+    }
+}

+ 2 - 1
src/main/java/io/milvus/bulkwriter/common/utils/GeneratorUtils.java

@@ -103,7 +103,7 @@ public class GeneratorUtils {
         Random random = new Random();
 
         for (int i = 0; i < dim; i++) {
-            rawVector[i] = random.nextInt(2); // 生成随机的 0 或 1
+            rawVector[i] = random.nextInt(2); // random 0 or 1
         }
 
         return rawVector;
@@ -121,6 +121,7 @@ public class GeneratorUtils {
             }
         }
 
+        // binary vector doesn't care endian since each byte is independent
         ByteBuffer byteBuffer = ByteBuffer.allocate(byteCount);
         for (byte b : binaryArray) {
             byteBuffer.put(b);

+ 215 - 0
src/main/java/io/milvus/common/utils/Float16Utils.java

@@ -0,0 +1,215 @@
+package io.milvus.common.utils;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.ShortBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class Float16Utils {
+    /**
+     * Converts a float32 into bf16. May not produce correct values for subnormal floats.
+     *
+     * This method is copied from microsoft ONNX Runtime:
+     * https://github.com/microsoft/onnxruntime/blob/main/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java
+     */
+    public static short floatToBf16(float input) {
+        int bits = Float.floatToIntBits(input);
+        int lsb = (bits >> 16) & 1;
+        int roundingBias = 0x7fff + lsb;
+        bits += roundingBias;
+        return (short) (bits >> 16);
+    }
+
+    /**
+     * Upcasts a bf16 value stored in a short into a float32 value.
+     *
+     * This method is copied from microsoft ONNX Runtime:
+     * https://github.com/microsoft/onnxruntime/blob/main/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java
+     */
+    public static float bf16ToFloat(short input) {
+        int bits = input << 16;
+        return Float.intBitsToFloat(bits);
+    }
+
+    /**
+     * Rounds a float32 value to a fp16 stored in a short.
+     *
+     * This method is copied from microsoft ONNX Runtime:
+     * https://github.com/microsoft/onnxruntime/blob/main/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java
+     */
+    public static short floatToFp16(float input) {
+        // Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h
+        int bits = Float.floatToIntBits(input);
+        final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY);
+        final int F16_MAX = (127 + 16) << 23;
+        final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23;
+        final int SIGN_MASK = 0x80000000;
+        final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff;
+
+        int sign = bits & SIGN_MASK;
+        // mask out sign bit
+        bits ^= sign;
+
+        short output;
+        if (bits >= F16_MAX) {
+            // Inf or NaN (all exponent bits set)
+            output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00;
+        } else {
+            if (bits < (113 << 23)) {
+                // Subnormal or zero
+                // use a magic value to align our 10 mantissa bits at the bottom of
+                // the float. as long as FP addition is round-to-nearest-even this
+                // just works.
+                float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC);
+
+                // and one integer subtract of the bias later, we have our final float!
+                output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC);
+            } else {
+                int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd
+
+                // update exponent, rounding bias part 1
+                bits += ROUNDING_CONST;
+                // rounding bias part 2
+                bits += mant_odd;
+                // take the bits!
+                output = (short) (bits >> 13);
+            }
+        }
+
+        // Add the sign back in
+        output = (short) (output | ((short) (sign >> 16)));
+
+        return output;
+    }
+
+    /**
+     * Upcasts a fp16 value stored in a short to a float32 value.
+     *
+     * This method is copied from microsoft ONNX Runtime:
+     * https://github.com/microsoft/onnxruntime/blob/main/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java
+     */
+    public static float fp16ToFloat(short input) {
+        // Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h
+        final int MAGIC = 113 << 23;
+        // exponent mask after shift
+        final int SHIFTED_EXP = 0x7c00 << 13;
+
+        // exponent/mantissa bits
+        int bits = (input & 0x7fff) << 13;
+        // just the exponent
+        final int exp = SHIFTED_EXP & bits;
+        // exponent adjust
+        bits += (127 - 15) << 23;
+
+        // handle exponent special cases
+        if (exp == SHIFTED_EXP) {
+            // Inf/NaN?
+            // extra exp adjust
+            bits += (128 - 16) << 23;
+        } else if (exp == 0) {
+            // Zero/Denormal?
+            // extra exp adjust
+            bits += (1 << 23);
+            // renormalize
+            float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC);
+            bits = Float.floatToIntBits(tmp);
+        }
+
+        // sign bit
+        bits |= (input & 0x8000) << 16;
+
+        return Float.intBitsToFloat(bits);
+    }
+
+    /**
+     * Rounds a float32 vector to bf16 values, and stores into a ByteBuffer.
+     */
+    public static ByteBuffer f32VectorToBf16Buffer(List<Float> vector) {
+        if (vector.isEmpty()) {
+            return null;
+        }
+
+        ByteBuffer buf = ByteBuffer.allocate(2 * vector.size());
+        buf.order(ByteOrder.LITTLE_ENDIAN); // milvus server stores fp16/bf16 vector as little endian
+        for (Float val : vector) {
+            short bf16 = floatToBf16(val);
+            buf.putShort(bf16);
+        }
+        return buf;
+    }
+
+    /**
+     * Converts a ByteBuffer to fp16 vector upcasts to float32 array.
+     */
+    public static List<Float> fp16BufferToVector(ByteBuffer buf) {
+        buf.rewind(); // reset the read position
+        List<Float> vector = new ArrayList<>();
+        ShortBuffer sbuf = buf.asShortBuffer();
+        for (int i = 0; i < sbuf.limit(); i++) {
+            float val = fp16ToFloat(sbuf.get(i));
+            vector.add(val);
+        }
+        return vector;
+    }
+
+    /**
+     * Rounds a float32 vector to fp16 values, and stores into a ByteBuffer.
+     */
+    public static ByteBuffer f32VectorToFp16Buffer(List<Float> vector) {
+        if (vector.isEmpty()) {
+            return null;
+        }
+
+        ByteBuffer buf = ByteBuffer.allocate(2 * vector.size());
+        buf.order(ByteOrder.LITTLE_ENDIAN); // milvus server stores fp16/bf16 vector as little endian
+        for (Float val : vector) {
+            short bf16 = floatToFp16(val);
+            buf.putShort(bf16);
+        }
+        return buf;
+    }
+
+    /**
+     * Converts a ByteBuffer to bf16 vector upcasts to float32 array.
+     */
+    public static List<Float> bf16BufferToVector(ByteBuffer buf) {
+        buf.rewind(); // reset the read position
+        List<Float> vector = new ArrayList<>();
+        ShortBuffer sbuf = buf.asShortBuffer();
+        for (int i = 0; i < sbuf.limit(); i++) {
+            float val = bf16ToFloat(sbuf.get(i));
+            vector.add(val);
+        }
+        return vector;
+    }
+
+    /**
+     * Stores a fp16/bf16 vector into a ByteBuffer.
+     */
+    public static ByteBuffer f16VectorToBuffer(List<Short> vector) {
+        if (vector.isEmpty()) {
+            return null;
+        }
+
+        ByteBuffer buf = ByteBuffer.allocate(2 * vector.size());
+        buf.order(ByteOrder.LITTLE_ENDIAN); // milvus server stores fp16/bf16 vector as little endian
+        for (Short val : vector) {
+            buf.putShort(val);
+        }
+        return buf;
+    }
+
+    /**
+     * Converts a ByteBuffer to a fp16/bf16 vector stored in short array.
+     */
+    public static List<Short> BufferToF16Vector(ByteBuffer buf) {
+        buf.rewind(); // reset the read position
+        List<Short> vector = new ArrayList<>();
+        ShortBuffer sbuf = buf.asShortBuffer();
+        for (int i = 0; i < sbuf.limit(); i++) {
+            vector.add(sbuf.get(i));
+        }
+        return vector;
+    }
+}

+ 10 - 3
src/main/java/io/milvus/param/ParamUtils.java

@@ -631,13 +631,16 @@ public class ParamUtils {
                 plType = PlaceholderType.FloatVector;
                 List<Float> list = (List<Float>) vector;
                 ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * list.size());
-                buf.order(ByteOrder.LITTLE_ENDIAN);
+                buf.order(ByteOrder.LITTLE_ENDIAN); // most of operating systems default little endian
                 list.forEach(buf::putFloat);
 
                 byte[] array = buf.array();
                 ByteString bs = ByteString.copyFrom(array);
                 byteStrings.add(bs);
             } else if (vector instanceof ByteBuffer) {
+                // for fp16/bf16 vector, each vector is a ByteBuffer with little endian
+                // for binary vector, each vector is a ByteBuffer no matter which endian
+                // the endian of each ByteBuffer is already specified by the caller
                 plType = PlaceholderType.BinaryVector;
                 ByteBuffer buf = (ByteBuffer) vector;
                 byte[] array = buf.array();
@@ -989,7 +992,9 @@ public class ParamUtils {
                 dataType == DataType.BFloat16Vector) {
             ByteBuffer totalBuf = null;
             int dim = 0;
-            // each object is ByteBuffer
+            // for fp16/bf16 vector, each vector is a ByteBuffer with little endian
+            // for binary vector, each vector is a ByteBuffer no matter which endian
+            // no need to set totalBuf endian since it is treated as byte array
             for (Object object : objects) {
                 ByteBuffer buf = (ByteBuffer) object;
                 if (totalBuf == null) {
@@ -1019,14 +1024,16 @@ public class ParamUtils {
     }
 
     private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
+        // milvus server requires sparse vector to be transfered in little endian
         ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
-        buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default
+        buf.order(ByteOrder.LITTLE_ENDIAN);
         for (Map.Entry<Long, Float> entry : sparse.entrySet()) {
             long k = entry.getKey();
             if (k < 0 || k >= (long)Math.pow(2.0, 32.0)-1) {
                 throw new ParamException("Sparse vector index must be positive and less than 2^32-1");
             }
             // here we construct a binary from the long key
+            // milvus server requires sparse vector to be transfered in little endian
             ByteBuffer lBuf = ByteBuffer.allocate(Long.BYTES);
             lBuf.order(ByteOrder.LITTLE_ENDIAN);
             lBuf.putLong(k);

+ 1 - 1
src/main/java/io/milvus/param/collection/FieldType.java

@@ -219,7 +219,7 @@ public class FieldType {
          * @return <code>Builder</code>
          */
         public Builder withMaxCapacity(@NonNull Integer maxCapacity) {
-            if (maxCapacity <= 0 || maxCapacity >= 4096) {
+            if (maxCapacity <= 0 || maxCapacity > 4096) {
                 throw new ParamException("Array field max capacity value must be within range [1, 4096]");
             }
             this.typeParams.put(Constant.ARRAY_MAX_CAPACITY, maxCapacity.toString());

+ 5 - 0
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -218,6 +218,10 @@ public class FieldDataWrapper {
                 List<ByteBuffer> packData = new ArrayList<>();
                 for (int i = 0; i < count; ++i) {
                     ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
+                    // binary vector doesn't care endian since each byte is independent
+                    // fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
+                    // milvus server stores fp16/bf16 vector as little endian
+                    bf.order(ByteOrder.LITTLE_ENDIAN);
                     bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
                     packData.add(bf);
                 }
@@ -237,6 +241,7 @@ public class FieldDataWrapper {
                     long num = bf.limit()/8; // each uint+float pair is 8 bytes
                     for (long j = 0; j < num; j++) {
                         // here we convert an uint 4-bytes to a long value
+                        // milvus server requires sparse vector to be transfered in little endian
                         ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
                         pBuf.order(ByteOrder.LITTLE_ENDIAN);
                         int offset = 8*(int)j;

+ 3 - 3
src/main/java/io/milvus/v2/service/collection/CollectionService.java

@@ -58,7 +58,7 @@ public class CollectionService extends BaseService {
                 .build();
 
         FieldSchema idSchema = FieldSchema.newBuilder()
-                .setName("id")
+                .setName(request.getPrimaryFieldName())
                 .setDataType(DataType.valueOf(request.getIdType().name()))
                 .setIsPrimaryKey(Boolean.TRUE)
                 .setAutoID(request.getAutoID())
@@ -75,11 +75,11 @@ public class CollectionService extends BaseService {
                 .setEnableDynamicField(request.getEnableDynamicField())
                 .build();
 
-
         CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder()
                 .setCollectionName(request.getCollectionName())
                 .setSchema(schema.toByteString())
                 .setShardsNum(request.getNumShards())
+                .setConsistencyLevelValue(request.getConsistencyLevel().getCode())
                 .build();
 
         Status status = blockingStub.createCollection(createCollectionRequest);
@@ -88,7 +88,7 @@ public class CollectionService extends BaseService {
         //create index
         IndexParam indexParam = IndexParam.builder()
                         .metricType(IndexParam.MetricType.valueOf(request.getMetricType()))
-                        .fieldName("vector")
+                        .fieldName(request.getVectorFieldName())
                         .build();
         CreateIndexReq createIndexReq = CreateIndexReq.builder()
                         .indexParams(Collections.singletonList(indexParam))

+ 6 - 1
src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

@@ -19,6 +19,7 @@
 
 package io.milvus.v2.service.collection.request;
 
+import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.exception.ErrorCode;
@@ -67,6 +68,9 @@ public class CreateCollectionReq {
     //private String partitionKeyField;
     private Integer numPartitions;
 
+    @Builder.Default
+    private ConsistencyLevel consistencyLevel = ConsistencyLevel.BOUNDED;
+
     @Data
     @SuperBuilder
     public static class CollectionSchema {
@@ -90,7 +94,8 @@ public class CreateCollectionReq {
                 fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
             } else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
                 fieldSchema.setMaxLength(addFieldReq.getMaxLength());
-            } else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector)) {
+            } else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) ||
+                    addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) {
                 if (addFieldReq.getDimension() == null) {
                     throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
                 }

+ 10 - 0
src/main/java/io/milvus/v2/service/vector/request/data/BFloat16Vec.java

@@ -19,9 +19,11 @@
 
 package io.milvus.v2.service.vector.request.data;
 
+import io.milvus.common.utils.Float16Utils;
 import io.milvus.grpc.PlaceholderType;
 
 import java.nio.ByteBuffer;
+import java.util.List;
 
 public class BFloat16Vec implements BaseVector {
     private final ByteBuffer data;
@@ -33,6 +35,14 @@ public class BFloat16Vec implements BaseVector {
         this.data = ByteBuffer.wrap(data);
     }
 
+    /**
+     * Construct a bfloat16 vector by a float32 array.
+     * Note that all the float32 values will be cast to bfloat16 values and store into ByteBuffer.
+     */
+    public BFloat16Vec(List<Float> data) {
+        this.data = Float16Utils.f32VectorToBf16Buffer(data);
+    }
+
     @Override
     public PlaceholderType getPlaceholderType() {
         return PlaceholderType.BFloat16Vector;

+ 11 - 0
src/main/java/io/milvus/v2/service/vector/request/data/Float16Vec.java

@@ -19,9 +19,11 @@
 
 package io.milvus.v2.service.vector.request.data;
 
+import io.milvus.common.utils.Float16Utils;
 import io.milvus.grpc.PlaceholderType;
 
 import java.nio.ByteBuffer;
+import java.util.List;
 
 public class Float16Vec implements BaseVector {
     private final ByteBuffer data;
@@ -32,6 +34,15 @@ public class Float16Vec implements BaseVector {
     public Float16Vec(byte[] data) {
         this.data = ByteBuffer.wrap(data);
     }
+
+    /**
+     * Construct a float16 vector by a float32 array.
+     * Note that all the float32 values will be cast to float16 values and store into ByteBuffer.
+     */
+    public Float16Vec(List<Float> data) {
+        this.data = Float16Utils.f32VectorToFp16Buffer(data);
+    }
+
     @Override
     public PlaceholderType getPlaceholderType() {
         return PlaceholderType.Float16Vector;

+ 224 - 0
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -28,6 +28,7 @@ import io.milvus.bulkwriter.common.clientenum.BulkFileType;
 import io.milvus.bulkwriter.common.utils.GeneratorUtils;
 import io.milvus.bulkwriter.common.utils.ParquetReaderUtils;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.common.utils.Float16Utils;
 import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
 import io.milvus.orm.iterator.QueryIterator;
@@ -51,6 +52,7 @@ import io.milvus.param.index.GetIndexStateParam;
 import io.milvus.param.partition.GetPartitionStatisticsParam;
 import io.milvus.param.partition.ShowPartitionsParam;
 import io.milvus.response.*;
+
 import org.apache.avro.generic.GenericData;
 import org.apache.commons.text.RandomStringGenerator;
 import org.apache.logging.log4j.LogManager;
@@ -75,6 +77,8 @@ class MilvusClientDockerTest {
     protected static MilvusClient client;
     protected static RandomStringGenerator generator;
     protected static final int dimension = 128;
+    protected static final float FLOAT16_PRECISION = 0.001f;
+    protected static final float BFLOAT16_PRECISION = 0.01f;
 
     protected static final Gson GSON_INSTANCE = new Gson();
 
@@ -897,6 +901,226 @@ class MilvusClientDockerTest {
         Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
     }
 
+    @Test
+    void testFloat16Vector() {
+        String randomCollectionName = generator.generate(10);
+
+        // collection schema
+        String field1Name = "id";
+        String field2Name = "float16_vector";
+        String field3Name = "bfloat16_vector";
+        FieldType field1 = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build();
+
+        FieldType field2 = FieldType.newBuilder()
+                .withDataType(DataType.Float16Vector)
+                .withName(field2Name)
+                .withDimension(dimension)
+                .build();
+
+        FieldType field3 = FieldType.newBuilder()
+                .withDataType(DataType.BFloat16Vector)
+                .withName(field3Name)
+                .withDimension(dimension)
+                .build();
+
+        // create collection
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .addFieldType(field1)
+                .addFieldType(field2)
+                .addFieldType(field3)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        // create index
+        R<RpcStatus> createIndexR = client.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFieldName(field2Name)
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.COSINE)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        createIndexR = client.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFieldName(field3Name)
+                .withIndexType(IndexType.IVF_FLAT)
+                .withMetricType(MetricType.COSINE)
+                .withExtraParam("{\"nlist\": 128}")
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        // load collection
+        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
+
+        // generate vectors
+        int rowCount = 10000;
+        List<List<Float>> vectors = generateFloatVectors(rowCount);
+
+        // insert by column-based
+        List<ByteBuffer> fp16Vectors = new ArrayList<>();
+        List<ByteBuffer> bf16Vectors = new ArrayList<>();
+        List<Long> ids = new ArrayList<>();
+        for (int i = 0; i < 5000; i++) {
+            ids.add((long)i);
+            List<Float> vector = vectors.get(i);
+            ByteBuffer fp16Vector = Float16Utils.f32VectorToFp16Buffer(vector);
+            fp16Vectors.add(fp16Vector);
+            ByteBuffer bf16Vector = Float16Utils.f32VectorToBf16Buffer(vector);
+            bf16Vectors.add(bf16Vector);
+        }
+
+        List<InsertParam.Field> fields = new ArrayList<>();
+        fields.add(new InsertParam.Field(field1Name, ids));
+        fields.add(new InsertParam.Field(field2Name, fp16Vectors));
+        fields.add(new InsertParam.Field(field3Name, bf16Vectors));
+
+        R<MutationResult> insertColumnResp = client.insert(InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFields(fields)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), insertColumnResp.getStatus().intValue());
+        System.out.println(ids.size() + " rows inserted");
+
+        // insert by row-based
+        List<JsonObject> rows = new ArrayList<>();
+        for (int i = 0; i < 5000; i++) {
+            JsonObject row = new JsonObject();
+            row.addProperty(field1Name, i + 5000);
+
+            List<Float> vector = vectors.get(i + 5000);
+            ByteBuffer fp16Vector = Float16Utils.f32VectorToFp16Buffer(vector);
+            row.add(field2Name, GSON_INSTANCE.toJsonTree(fp16Vector.array()));
+            ByteBuffer bf16Vector = Float16Utils.f32VectorToBf16Buffer(vector);
+            row.add(field3Name, GSON_INSTANCE.toJsonTree(bf16Vector.array()));
+            rows.add(row);
+        }
+
+        insertColumnResp = client.insert(InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withRows(rows)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), insertColumnResp.getStatus().intValue());
+        System.out.println(rows.size() + " rows inserted");
+
+        // query
+        List<Long> targetIDs = Arrays.asList(100L, 8888L);
+        String expr = String.format("%s in %s", field1Name, targetIDs);
+        R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(expr)
+                .addOutField(field2Name)
+                .addOutField(field3Name)
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
+
+        // verify query result
+        QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
+        List<QueryResultsWrapper.RowRecord> records = fetchWrapper.getRowRecords();
+        Assertions.assertEquals(targetIDs.size(), records.size());
+        for (int i = 0; i < records.size(); i++) {
+            QueryResultsWrapper.RowRecord record = records.get(i);
+            Assertions.assertEquals(targetIDs.get(i), record.get(field1Name));
+            Assertions.assertInstanceOf(ByteBuffer.class, record.get(field2Name));
+            Assertions.assertInstanceOf(ByteBuffer.class, record.get(field3Name));
+
+            List<Float> originVector = vectors.get(targetIDs.get(i).intValue());
+            ByteBuffer buf1 = (ByteBuffer) record.get(field2Name);
+            List<Float> fp16Vec = Float16Utils.fp16BufferToVector(buf1);
+            Assertions.assertEquals(fp16Vec.size(), originVector.size());
+            for (int k = 0; k < fp16Vec.size(); k++) {
+                Assertions.assertTrue(Math.abs(fp16Vec.get(k) - originVector.get(k)) <= FLOAT16_PRECISION);
+            }
+
+            ByteBuffer buf2 = (ByteBuffer) record.get(field3Name);
+            List<Float> bf16Vec = Float16Utils.bf16BufferToVector(buf2);
+            Assertions.assertEquals(bf16Vec.size(), originVector.size());
+            for (int k = 0; k < bf16Vec.size(); k++) {
+                Assertions.assertTrue(Math.abs(bf16Vec.get(k) - originVector.get(k)) <= BFLOAT16_PRECISION);
+            }
+        }
+
+        // search float16 vector
+        long targetID = new Random().nextInt(rowCount);
+        List<Float> originVector = vectors.get((int) targetID);
+        ByteBuffer fp16Vector = Float16Utils.f32VectorToFp16Buffer(originVector);
+
+        int topK = 5;
+        R<SearchResults> searchR = client.search(SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.COSINE)
+                .withTopK(topK)
+                .withFloat16Vectors(Collections.singletonList(fp16Vector))
+                .withVectorFieldName(field2Name)
+                .addOutField(field2Name)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // verify the search result of float16
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
+        System.out.println("The result of float16 vector(ID = " + targetID + "):");
+        System.out.println(scores);
+        Assertions.assertEquals(topK, scores.size());
+        Assertions.assertEquals(targetID, scores.get(0).getLongID());
+
+        Object v = scores.get(0).get(field2Name);
+        Assertions.assertInstanceOf(ByteBuffer.class, v);
+        List<Float> fp16Vec = Float16Utils.fp16BufferToVector((ByteBuffer)v);
+        Assertions.assertEquals(fp16Vec.size(), originVector.size());
+        for (int k = 0; k < fp16Vec.size(); k++) {
+            Assertions.assertTrue(Math.abs(fp16Vec.get(k) - originVector.get(k)) <= FLOAT16_PRECISION);
+        }
+
+        // search bfloat16 vector
+        ByteBuffer bf16Vector = Float16Utils.f32VectorToBf16Buffer(vectors.get((int) targetID));
+        searchR = client.search(SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.COSINE)
+                .withTopK(topK)
+                .withParams("{\"nprobe\": 16}")
+                .withBFloat16Vectors(Collections.singletonList(bf16Vector))
+                .withVectorFieldName(field3Name)
+                .addOutField(field3Name)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // verify the search result of bfloat16
+        results = new SearchResultsWrapper(searchR.getData().getResults());
+        scores = results.getIDScore(0);
+        System.out.println("The result of bfloat16 vector(ID = " + targetID + "):");
+        System.out.println(scores);
+        Assertions.assertEquals(topK, scores.size());
+        Assertions.assertEquals(targetID, scores.get(0).getLongID());
+
+        v = scores.get(0).get(field3Name);
+        Assertions.assertInstanceOf(ByteBuffer.class, v);
+        List<Float> bf16Vec = Float16Utils.bf16BufferToVector((ByteBuffer)v);
+        Assertions.assertEquals(bf16Vec.size(), originVector.size());
+        for (int k = 0; k < bf16Vec.size(); k++) {
+            Assertions.assertTrue(Math.abs(bf16Vec.get(k) - originVector.get(k)) <= BFLOAT16_PRECISION);
+        }
+
+        // drop collection
+        DropCollectionParam dropParam = DropCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build();
+
+        R<RpcStatus> dropR = client.dropCollection(dropParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
+    }
+
     @Test
     void testMultipleVectorFields() {
         String randomCollectionName = generator.generate(10);

+ 150 - 1
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -22,6 +22,7 @@ package io.milvus.v2.client;
 import com.google.gson.*;
 
 import com.google.gson.reflect.TypeToken;
+import io.milvus.common.utils.Float16Utils;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
@@ -40,6 +41,7 @@ import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.request.data.*;
 import io.milvus.v2.service.vector.request.ranker.*;
 import io.milvus.v2.service.vector.response.*;
+import io.netty.buffer.ByteBuf;
 import org.apache.commons.text.RandomStringGenerator;
 
 import org.junit.jupiter.api.Assertions;
@@ -64,7 +66,7 @@ class MilvusClientV2DockerTest {
     private static final Random RANDOM = new Random();
 
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:2.4-20240605-443197bd-amd64");
+    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.4");
 
     @BeforeAll
     public static void setUp() {
@@ -117,6 +119,16 @@ class MilvusClientV2DockerTest {
 
     }
 
+    private ByteBuffer generateFloat16Vector() {
+        List<Float> vector = generateFolatVector();
+        return Float16Utils.f32VectorToFp16Buffer(vector);
+    }
+
+    private ByteBuffer generateBFloat16Vector() {
+        List<Float> vector = generateFolatVector();
+        return Float16Utils.f32VectorToBf16Buffer(vector);
+    }
+
     private SortedMap<Long, Float> generateSparseVector() {
         SortedMap<Long, Float> sparse = new TreeMap<>();
         int dim = RANDOM.nextInt(10) + 10;
@@ -319,6 +331,16 @@ class MilvusClientV2DockerTest {
                         row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector.array()));
                         break;
                     }
+                    case Float16Vector: {
+                        ByteBuffer vector = generateFloat16Vector();
+                        row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector.array()));
+                        break;
+                    }
+                    case BFloat16Vector: {
+                        ByteBuffer vector = generateBFloat16Vector();
+                        row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector.array()));
+                        break;
+                    }
                     case SparseFloatVector: {
                         SortedMap<Long, Float> vector = generateSparseVector();
                         row.add(field.getName(), GSON_INSTANCE.toJsonTree(vector));
@@ -618,6 +640,133 @@ class MilvusClientV2DockerTest {
         client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
     }
 
+    @Test
+    void testFloat16Vectors() {
+        String randomCollectionName = generator.generate(10);
+
+        // build a collection with two vector fields
+        String float16Field = "float16_vector";
+        String bfloat16Field = "bfloat16_vector";
+        CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(float16Field)
+                .dataType(DataType.Float16Vector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(bfloat16Field)
+                .dataType(DataType.BFloat16Vector)
+                .dimension(dimension)
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        Map<String,Object> extraParams = new HashMap<>();
+        extraParams.put("nlist",64);
+        indexes.add(IndexParam.builder()
+                .fieldName(float16Field)
+                .indexType(IndexParam.IndexType.IVF_FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .extraParams(extraParams)
+                .build());
+        indexes.add(IndexParam.builder()
+                .fieldName(bfloat16Field)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
+                .build();
+        client.createCollection(requestCreate);
+
+        // insert 10000 rows
+        long count = 10000;
+        List<JsonObject> data = generateRandomData(collectionSchema, count);
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(data)
+                .build());
+        Assertions.assertEquals(count, insertResp.getInsertCnt());
+
+        // update one row
+        long targetID = 99;
+        JsonObject row = data.get((int)targetID);
+        List<Float> originVector = new ArrayList<>();
+        for (int i = 0; i < dimension; ++i) {
+            originVector.add((float)1/(i+1));
+        }
+        System.out.println("Original float32 vector: " + originVector);
+        row.add(float16Field, GSON_INSTANCE.toJsonTree(Float16Utils.f32VectorToFp16Buffer(originVector).array()));
+        row.add(bfloat16Field, GSON_INSTANCE.toJsonTree(Float16Utils.f32VectorToBf16Buffer(originVector).array()));
+
+        UpsertResp upsertResp = client.upsert(UpsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(Collections.singletonList(row))
+                .build());
+        Assertions.assertEquals(1L, upsertResp.getUpsertCnt());
+
+        int topk = 10;
+        // search the float16 vector field
+        {
+            SearchResp searchResp = client.search(SearchReq.builder()
+                    .collectionName(randomCollectionName)
+                    .annsField(float16Field)
+                    .data(Collections.singletonList(new Float16Vec(originVector)))
+                    .topK(topk)
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .outputFields(Collections.singletonList(float16Field))
+                    .build());
+            List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+            Assertions.assertEquals(1, searchResults.size());
+            List<SearchResp.SearchResult> results = searchResults.get(0);
+            Assertions.assertEquals(topk, results.size());
+            SearchResp.SearchResult firstResult = results.get(0);
+            Assertions.assertEquals(targetID, (long) firstResult.getId());
+            Map<String, Object> entity = firstResult.getEntity();
+            Assertions.assertInstanceOf(ByteBuffer.class, entity.get(float16Field));
+            ByteBuffer outputBuf = (ByteBuffer) entity.get(float16Field);
+            List<Float> outputVector = Float16Utils.fp16BufferToVector(outputBuf);
+            for (int i = 0; i < outputVector.size(); i++) {
+                Assertions.assertEquals(originVector.get(i), outputVector.get(i), 0.001f);
+            }
+            System.out.println("Output float16 vector: " + outputVector);
+        }
+
+        // search the bfloat16 vector field
+        {
+            SearchResp searchResp = client.search(SearchReq.builder()
+                    .collectionName(randomCollectionName)
+                    .annsField(bfloat16Field)
+                    .data(Collections.singletonList(new BFloat16Vec(originVector)))
+                    .topK(topk)
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .outputFields(Collections.singletonList(bfloat16Field))
+                    .build());
+            List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+            Assertions.assertEquals(1, searchResults.size());
+            List<SearchResp.SearchResult> results = searchResults.get(0);
+            Assertions.assertEquals(topk, results.size());
+            SearchResp.SearchResult firstResult = results.get(0);
+            Assertions.assertEquals(targetID, (long) firstResult.getId());
+            Map<String, Object> entity = firstResult.getEntity();
+            Assertions.assertInstanceOf(ByteBuffer.class, entity.get(bfloat16Field));
+            ByteBuffer outputBuf = (ByteBuffer) entity.get(bfloat16Field);
+            List<Float> outputVector = Float16Utils.bf16BufferToVector(outputBuf);
+            for (int i = 0; i < outputVector.size(); i++) {
+                Assertions.assertEquals(originVector.get(i), outputVector.get(i), 0.01f);
+            }
+            System.out.println("Output bfloat16 vector: " + outputVector);
+        }
+
+        // get row count
+        long rowCount = getRowCount(randomCollectionName);
+        Assertions.assertEquals(count, rowCount);
+
+        client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
+    }
+
     @Test
     void testSparseVectors() {
         String randomCollectionName = generator.generate(10);