
BulkWriter supports Float16Vector/BFloat16Vector/SparseVector (#960)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot committed 10 months ago
commit 235801517c

+ 205 - 163
examples/main/java/io/milvus/v1/BulkWriterExample.java

@@ -23,6 +23,7 @@ import com.fasterxml.jackson.dataformat.csv.CsvMapper;
 import com.fasterxml.jackson.dataformat.csv.CsvSchema;
 import com.google.common.collect.Lists;
 import com.google.gson.Gson;
+import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
 import com.google.gson.reflect.TypeToken;
 import io.milvus.bulkwriter.*;
@@ -58,6 +59,7 @@ import java.io.File;
 import java.io.IOException;
 import java.net.MalformedURLException;
 import java.net.URL;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
@@ -74,6 +76,8 @@ public class BulkWriterExample {
 
     private static final Gson GSON_INSTANCE = new Gson();
 
+    private static final List<Integer> QUERY_IDS = Lists.newArrayList(100, 5000);
+
 
     /**
      * If you need to transfer the files generated by bulkWriter to the corresponding remote storage (AWS S3, GCP GCS, Azure Blob, Aliyun OSS, Tencent Cloud TOS),
@@ -138,6 +142,7 @@ public class BulkWriterExample {
     private static final String SIMPLE_COLLECTION_NAME = "java_sdk_bulkwriter_simple_v1";
     private static final String ALL_TYPES_COLLECTION_NAME = "java_sdk_bulkwriter_all_v1";
     private static final Integer DIM = 512;
+    private static final Integer ARRAY_CAPACITY = 10;
     private MilvusClient milvusClient;
 
     public static void main(String[] args) throws Exception {
@@ -149,7 +154,7 @@ public class BulkWriterExample {
         );
 
         exampleSimpleCollection(exampleBulkWriter, fileTypes);
-        exampleAllTypeCollectionRemote(exampleBulkWriter, fileTypes);
+        exampleAllTypesCollectionRemote(exampleBulkWriter, fileTypes);
 
         // to call cloud import api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud)
         // exampleCloudImport();
@@ -167,7 +172,7 @@ public class BulkWriterExample {
     }
 
     private static void exampleSimpleCollection(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
-        CollectionSchemaParam collectionSchema = exampleBulkWriter.buildSimpleCollection();
+        CollectionSchemaParam collectionSchema = exampleBulkWriter.buildSimpleSchema();
         exampleBulkWriter.createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);
 
         for (BulkFileType fileType : fileTypes) {
@@ -182,31 +187,23 @@ public class BulkWriterExample {
         parallelAppend(collectionSchema);
     }
 
-    private static void exampleAllTypeCollectionRemote(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
-        // float vectors + all scalar types, use bulkInsert interface
-        for (BulkFileType fileType : fileTypes) {
-            CollectionSchemaParam collectionSchema = buildAllTypeSchema(false, true);
-            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(false, collectionSchema, fileType);
-            exampleBulkWriter.callBulkInsert(collectionSchema, batchFiles);
-            exampleBulkWriter.retrieveImportData(false);
-        }
-
-        // binary vectors + all scalar types, use bulkInsert interface
+    private static void exampleAllTypesCollectionRemote(BulkWriterExample exampleBulkWriter, List<BulkFileType> fileTypes) throws Exception {
+        // 4 types vectors + all scalar types + dynamic field enabled, use bulkInsert interface
         for (BulkFileType fileType : fileTypes) {
-            CollectionSchemaParam collectionSchema = buildAllTypeSchema(true, true);
-            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(true, collectionSchema, fileType);
+            CollectionSchemaParam collectionSchema = buildAllTypesSchema();
+            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
             exampleBulkWriter.callBulkInsert(collectionSchema, batchFiles);
-            exampleBulkWriter.retrieveImportData(true);
+            exampleBulkWriter.retrieveImportData();
         }
 
-//        // float vectors + all scalar types, use cloud import api.
+//        // 4 types vectors + all scalar types + dynamic field enabled, use cloud import api.
 //        // You need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud)
 //        for (BulkFileType fileType : fileTypes) {
-//            CollectionSchemaParam collectionSchema = buildAllTypeSchema(false, true);
-//            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(false, collectionSchema, fileType);
+//            CollectionSchemaParam collectionSchema = buildAllTypesSchema();
+//            List<List<String>> batchFiles = exampleBulkWriter.allTypesRemoteWriter(collectionSchema, fileType);
 //            exampleBulkWriter.createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
 //            exampleBulkWriter.callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME);
-//            exampleBulkWriter.retrieveImportData(false);
+//            exampleBulkWriter.retrieveImportData();
 //        }
     }
 
@@ -355,8 +352,8 @@ public class BulkWriterExample {
         }
     }
 
-    private List<List<String>> allTypesRemoteWriter(boolean binVec, CollectionSchemaParam collectionSchema, BulkFileType fileType) throws Exception {
-        System.out.printf("\n===================== all field types (%s) binary_vector=%s ====================%n", fileType.name(), binVec);
+    private List<List<String>> allTypesRemoteWriter(CollectionSchemaParam collectionSchema, BulkFileType fileType) throws Exception {
+        System.out.printf("\n===================== all field types (%s) ====================%n", fileType.name());
 
         try (RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, fileType)) {
             System.out.println("Append rows");
@@ -377,19 +374,29 @@ public class BulkWriterExample {
                 rowObject.addProperty("json", String.format("{\"dummy\": %s, \"ok\": \"name_%s\"}", i, i));
 
                 // vector field
-                rowObject.add("vector",
-                        binVec ? GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorBinaryVector(128).array()) :
-                                GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorFloatValue(128)));
+                rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloatVector(DIM)));
+                rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateBinaryVector(DIM).array()));
+                rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloat16Vector(DIM, false).array()));
+                rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateSparseVector()));
 
                 // array field
-                rowObject.add("arrayInt64", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorLongValue(10)));
-                rowObject.add("arrayVarchar", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorVarcharValue(10, 10)));
-                rowObject.add("arrayInt8", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt8Value(10)));
-                rowObject.add("arrayInt16", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt16Value(10)));
-                rowObject.add("arrayInt32", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt32Value(10)));
-                rowObject.add("arrayFloat", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorFloatValue(10)));
-                rowObject.add("arrayDouble", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorDoubleValue(10)));
-                rowObject.add("arrayBool", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorBoolValue(10)));
+                rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorBoolValue(10)));
+                rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt8Value(10)));
+                rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt16Value(10)));
+                rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorInt32Value(10)));
+                rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorLongValue(10)));
+                rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorVarcharValue(10, 10)));
+                rowObject.add("array_float", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorFloatValue(10)));
+                rowObject.add("array_double", GSON_INSTANCE.toJsonTree(GeneratorUtils.generatorDoubleValue(10)));
+
+                // dynamic fields
+                if (collectionSchema.isEnableDynamicField()) {
+                    rowObject.addProperty("dynamic", "dynamic_" + i);
+                }
+
+                if (QUERY_IDS.contains(i)) {
+                    System.out.println(rowObject);
+                }
 
                 remoteBulkWriter.appendRow(rowObject);
             }
@@ -580,41 +587,91 @@ public class BulkWriterExample {
         System.out.printf("Collection %s created%n", collectionName);
     }
 
-    private void retrieveImportData(boolean binVec) {
-        createIndex(binVec);
+    private void retrieveImportData() {
+        createIndex();
 
-        List<Integer> ids = Lists.newArrayList(100, 5000);
-        System.out.printf("Load collection and query items %s%n", ids);
+        System.out.printf("Load collection and query items %s%n", QUERY_IDS);
         loadCollection();
 
-        String expr = String.format("id in %s", ids);
+        String expr = String.format("id in %s", QUERY_IDS);
         System.out.println(expr);
 
-        List<QueryResultsWrapper.RowRecord> rowRecords = query(expr, Lists.newArrayList("*", "vector"));
+        List<QueryResultsWrapper.RowRecord> rowRecords = query(expr, Lists.newArrayList("*"));
         System.out.println("Query results:");
         for (QueryResultsWrapper.RowRecord record : rowRecords) {
-            System.out.println(record);
+            JsonObject rowObject = new JsonObject();
+            // scalar field
+            rowObject.addProperty("id", (Long)record.get("id"));
+            rowObject.addProperty("bool", (Boolean) record.get("bool"));
+            rowObject.addProperty("int8", (Integer) record.get("int8"));
+            rowObject.addProperty("int16", (Integer) record.get("int16"));
+            rowObject.addProperty("int32", (Integer) record.get("int32"));
+            rowObject.addProperty("float", (Float) record.get("float"));
+            rowObject.addProperty("double", (Double) record.get("double"));
+            rowObject.addProperty("varchar", (String) record.get("varchar"));
+            rowObject.add("json", (JsonElement) record.get("json"));
+
+            // vector field
+            rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(record.get("float_vector")));
+            rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer)record.get("binary_vector")).array()));
+            rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(((ByteBuffer)record.get("float16_vector")).array()));
+            rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(record.get("sparse_vector")));
+
+            // array field
+            rowObject.add("array_bool", GSON_INSTANCE.toJsonTree(record.get("array_bool")));
+            rowObject.add("array_int8", GSON_INSTANCE.toJsonTree(record.get("array_int8")));
+            rowObject.add("array_int16", GSON_INSTANCE.toJsonTree(record.get("array_int16")));
+            rowObject.add("array_int32", GSON_INSTANCE.toJsonTree(record.get("array_int32")));
+            rowObject.add("array_int64", GSON_INSTANCE.toJsonTree(record.get("array_int64")));
+            rowObject.add("array_varchar", GSON_INSTANCE.toJsonTree(record.get("array_varchar")));
+            rowObject.add("array_float", GSON_INSTANCE.toJsonTree(record.get("array_float")));
+            rowObject.add("array_double", GSON_INSTANCE.toJsonTree(record.get("array_double")));
+
+            // dynamic field
+            rowObject.addProperty("dynamic", (String) record.get("dynamic"));
+
+            System.out.println(rowObject);
         }
     }
 
-    private void createIndex(boolean binVec) {
+    private void createIndex() {
         System.out.println("Create index...");
         checkMilvusClientIfExist();
-        CreateIndexParam.Builder builder = CreateIndexParam.newBuilder()
+
+        R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
                 .withCollectionName(ALL_TYPES_COLLECTION_NAME)
-                .withFieldName("vector")
-                .withIndexName("index_name")
-                .withSyncMode(Boolean.TRUE);
+                .withFieldName("float_vector")
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.L2)
+                .withSyncMode(Boolean.TRUE)
+                .build());
+        ExceptionUtils.handleResponseStatus(response);
 
-        if (binVec) {
-            builder.withIndexType(IndexType.BIN_FLAT);
-            builder.withMetricType(MetricType.HAMMING);
-        } else {
-            builder.withIndexType(IndexType.FLAT);
-            builder.withMetricType(MetricType.L2);
-        }
+        response = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(ALL_TYPES_COLLECTION_NAME)
+                .withFieldName("binary_vector")
+                .withIndexType(IndexType.BIN_FLAT)
+                .withMetricType(MetricType.HAMMING)
+                .withSyncMode(Boolean.TRUE)
+                .build());
+        ExceptionUtils.handleResponseStatus(response);
 
-        R<RpcStatus> response = milvusClient.createIndex(builder.build());
+        response = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(ALL_TYPES_COLLECTION_NAME)
+                .withFieldName("float16_vector")
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.IP)
+                .withSyncMode(Boolean.TRUE)
+                .build());
+        ExceptionUtils.handleResponseStatus(response);
+
+        response = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(ALL_TYPES_COLLECTION_NAME)
+                .withFieldName("sparse_vector")
+                .withIndexType(IndexType.SPARSE_WAND)
+                .withMetricType(MetricType.IP)
+                .withSyncMode(Boolean.TRUE)
+                .build());
         ExceptionUtils.handleResponseStatus(response);
     }
 
@@ -694,7 +751,7 @@ public class BulkWriterExample {
         System.out.println(GSON_INSTANCE.toJson(listImportJobsResponse));
     }
 
-    private CollectionSchemaParam buildSimpleCollection() {
+    private CollectionSchemaParam buildSimpleSchema() {
         FieldType fieldType1 = FieldType.newBuilder()
                 .withName("id")
                 .withDataType(DataType.Int64)
@@ -722,163 +779,148 @@ public class BulkWriterExample {
                 .withMaxLength(512)
                 .build();
 
-        CollectionSchemaParam collectionSchema = CollectionSchemaParam.newBuilder()
+        return CollectionSchemaParam.newBuilder()
                 .addFieldType(fieldType1)
                 .addFieldType(fieldType2)
                 .addFieldType(fieldType3)
                 .addFieldType(fieldType4)
                 .build();
-
-        return collectionSchema;
     }
 
-    private static CollectionSchemaParam buildAllTypeSchema(boolean binVec, boolean hasArray) {
+    private static CollectionSchemaParam buildAllTypesSchema() {
+        List<FieldType> fieldTypes = new ArrayList<>();
         // scalar field
-        FieldType fieldType1 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("id")
                 .withDataType(DataType.Int64)
                 .withPrimaryKey(true)
                 .withAutoID(false)
-                .build();
+                .build());
 
-        FieldType fieldType2 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("bool")
                 .withDataType(DataType.Bool)
-                .build();
+                .build());
 
-        FieldType fieldType3 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("int8")
                 .withDataType(DataType.Int8)
-                .build();
+                .build());
 
-        FieldType fieldType4 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("int16")
                 .withDataType(DataType.Int16)
-                .build();
+                .build());
 
-        FieldType fieldType5 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("int32")
                 .withDataType(DataType.Int32)
-                .build();
+                .build());
 
-        FieldType fieldType6 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("float")
                 .withDataType(DataType.Float)
-                .build();
+                .build());
 
-        FieldType fieldType7 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("double")
                 .withDataType(DataType.Double)
-                .build();
+                .build());
 
-        FieldType fieldType8 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("varchar")
                 .withDataType(DataType.VarChar)
                 .withMaxLength(512)
-                .build();
+                .build());
 
-        FieldType fieldType9 = FieldType.newBuilder()
+        fieldTypes.add(FieldType.newBuilder()
                 .withName("json")
                 .withDataType(DataType.JSON)
-                .build();
+                .build());
 
-        // vector field
-        FieldType fieldType10;
-        if (binVec) {
-            fieldType10 = FieldType.newBuilder()
-                    .withName("vector")
-                    .withDataType(DataType.BinaryVector)
-                    .withDimension(128)
-                    .build();
-        } else {
-            fieldType10 = FieldType.newBuilder()
-                    .withName("vector")
-                    .withDataType(DataType.FloatVector)
-                    .withDimension(128)
-                    .build();
-        }
+        // vector fields
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("float_vector")
+                .withDataType(DataType.FloatVector)
+                .withDimension(DIM)
+                .build());
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("binary_vector")
+                .withDataType(DataType.BinaryVector)
+                .withDimension(DIM)
+                .build());
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("float16_vector")
+                .withDataType(DataType.Float16Vector)
+                .withDimension(DIM)
+                .build());
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("sparse_vector")
+                .withDataType(DataType.SparseFloatVector)
+                .build());
 
-        CollectionSchemaParam.Builder schemaBuilder = CollectionSchemaParam.newBuilder()
-                .withEnableDynamicField(false)
-                .addFieldType(fieldType1)
-                .addFieldType(fieldType2)
-                .addFieldType(fieldType3)
-                .addFieldType(fieldType4)
-                .addFieldType(fieldType5)
-                .addFieldType(fieldType6)
-                .addFieldType(fieldType7)
-                .addFieldType(fieldType8)
-                .addFieldType(fieldType9)
-                .addFieldType(fieldType10);
-
-        // array field
-        if (hasArray) {
-            FieldType fieldType11 = FieldType.newBuilder()
-                    .withName("arrayInt64")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Int64)
-                    .withMaxCapacity(10)
-                    .build();
+        // array fields
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_bool")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Bool)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType12 = FieldType.newBuilder()
-                    .withName("arrayVarchar")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.VarChar)
-                    .withMaxLength(10)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_int8")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Int8)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType13 = FieldType.newBuilder()
-                    .withName("arrayInt8")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Int8)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_int16")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Int16)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType14 = FieldType.newBuilder()
-                    .withName("arrayInt16")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Int16)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_int32")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Int32)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType15 = FieldType.newBuilder()
-                    .withName("arrayInt32")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Int32)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_int64")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Int64)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType16 = FieldType.newBuilder()
-                    .withName("arrayFloat")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Float)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_varchar")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.VarChar)
+                .withMaxLength(512)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType17 = FieldType.newBuilder()
-                    .withName("arrayDouble")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Double)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_float")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Float)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
 
-            FieldType fieldType18 = FieldType.newBuilder()
-                    .withName("arrayBool")
-                    .withDataType(DataType.Array)
-                    .withElementType(DataType.Bool)
-                    .withMaxCapacity(10)
-                    .build();
+        fieldTypes.add(FieldType.newBuilder()
+                .withName("array_double")
+                .withDataType(DataType.Array)
+                .withElementType(DataType.Double)
+                .withMaxCapacity(ARRAY_CAPACITY)
+                .build());
+
+        CollectionSchemaParam.Builder schemaBuilder = CollectionSchemaParam.newBuilder()
+                .withEnableDynamicField(true)
+                .withFieldTypes(fieldTypes);
 
-            schemaBuilder.addFieldType(fieldType11)
-                    .addFieldType(fieldType12)
-                    .addFieldType(fieldType13)
-                    .addFieldType(fieldType14)
-                    .addFieldType(fieldType15)
-                    .addFieldType(fieldType16)
-                    .addFieldType(fieldType17)
-                    .addFieldType(fieldType18);
-        }
         return schemaBuilder.build();
     }
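
For reference, here is a minimal caller-side sketch of a row carrying the three new vector types, following the appendRow pattern shown in the hunk above. It assumes the GSON_INSTANCE, DIM and CommonUtils helpers already used by BulkWriterExample (plus java.util.SortedMap/TreeMap imports); the sparse values are illustrative only, and the scalar/array fields required by the schema are omitted for brevity.

    // Sketch only: a real row must also contain every scalar/array field of the schema,
    // as in the full loop above, because the writer validates all schema fields.
    private static void appendOneRow(RemoteBulkWriter writer, long id) throws Exception {
        JsonObject row = new JsonObject();
        row.addProperty("id", id);

        // dense float vector: serialized as a JSON array of DIM floats
        row.add("float_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloatVector(DIM)));

        // binary and float16 vectors: serialized as the raw byte array of their ByteBuffer
        row.add("binary_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateBinaryVector(DIM).array()));
        row.add("float16_vector", GSON_INSTANCE.toJsonTree(CommonUtils.generateFloat16Vector(DIM, false).array()));

        // sparse vector: a dimension-index -> value map, serialized as a JSON object
        SortedMap<Long, Float> sparse = new TreeMap<>();
        sparse.put(2L, 0.33f);
        sparse.put(1000L, 0.75f);
        row.add("sparse_vector", GSON_INSTANCE.toJsonTree(sparse));

        writer.appendRow(row);
    }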
 

+ 19 - 7
src/main/java/io/milvus/bulkwriter/Buffer.java

@@ -44,6 +44,7 @@ import java.nio.ByteBuffer;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.SortedMap;
 import java.util.stream.Collectors;
 
 import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME;
@@ -217,8 +218,13 @@ public class Buffer {
                 addFloatArray(group, paramName, (List<Float>) value);
                 break;
             case BinaryVector:
+            case Float16Vector:
+            case BFloat16Vector:
                 addBinaryVector(group, paramName, (ByteBuffer) value);
                 break;
+            case SparseFloatVector:
+                addSparseVector(group, paramName, (SortedMap<Long, Float>) value);
+                break;
             case Array:
                 DataType elementType = fieldType.getElementType();
                 switch (elementType) {
@@ -279,28 +285,34 @@ public class Buffer {
         }
     }
 
-    private static void addBinaryVector(Group group, String fieldName, ByteBuffer byteBuffer) {
+    private static void addDoubleArray(Group group, String fieldName, List<Double> values) {
         Group arrayGroup = group.addGroup(fieldName);
-        byte[] bytes = byteBuffer.array();
-        for (byte value : bytes) {
+        for (double value : values) {
             Group addGroup = arrayGroup.addGroup(0);
             addGroup.add(0, value);
         }
     }
 
-    private static void addDoubleArray(Group group, String fieldName, List<Double> values) {
+    private static void addBooleanArray(Group group, String fieldName, List<Boolean> values) {
         Group arrayGroup = group.addGroup(fieldName);
-        for (double value : values) {
+        for (boolean value : values) {
             Group addGroup = arrayGroup.addGroup(0);
             addGroup.add(0, value);
         }
     }
 
-    private static void addBooleanArray(Group group, String fieldName, List<Boolean> values) {
+    private static void addBinaryVector(Group group, String fieldName, ByteBuffer byteBuffer) {
         Group arrayGroup = group.addGroup(fieldName);
-        for (boolean value : values) {
+        byte[] bytes = byteBuffer.array();
+        for (byte value : bytes) {
             Group addGroup = arrayGroup.addGroup(0);
             addGroup.add(0, value);
         }
     }
+
+    private static void addSparseVector(Group group, String fieldName, SortedMap<Long, Float> sparse) {
+        // sparse vector is parsed as JSON format string in the server side
+        String jsonString = GSON_INSTANCE.toJson(sparse);
+        group.append(fieldName, jsonString);
+    }
 }
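
As the comment in addSparseVector notes, a sparse vector cell ends up in the Parquet file as a single JSON string. A quick illustration, assuming Gson's default map serialization, of what that string looks like for a small sparse vector:

    // Illustration only: the JSON produced for a SortedMap<Long, Float> sparse vector.
    SortedMap<Long, Float> sparse = new TreeMap<>();
    sparse.put(2L, 0.33f);
    sparse.put(1000L, 0.75f);
    String jsonString = new Gson().toJson(sparse);
    // jsonString looks like {"2":0.33,"1000":0.75}  (dimension index -> value)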

+ 19 - 7
src/main/java/io/milvus/bulkwriter/BulkWriter.java

@@ -143,7 +143,10 @@ public abstract class BulkWriter {
             DataType dataType = fieldType.getDataType();
             switch (dataType) {
                 case BinaryVector:
-                case FloatVector: {
+                case FloatVector:
+                case Float16Vector:
+                case BFloat16Vector:
+                case SparseFloatVector: {
                     Pair<Object, Integer> objectAndSize = verifyVector(obj, fieldType);
                     rowValues.put(fieldName, objectAndSize.getLeft());
                     rowSize += objectAndSize.getRight();
@@ -216,13 +219,22 @@ public abstract class BulkWriter {
     }
 
     private Pair<Object, Integer> verifyVector(JsonElement object, FieldType fieldType) {
-        if (fieldType.getDataType() == DataType.FloatVector) {
-            Object vector = ParamUtils.checkFieldValue(fieldType, object);
-            return Pair.of(vector, ((List<?>)vector).size() * 4);
-        } else {
-            Object vector = ParamUtils.checkFieldValue(fieldType, object);
-            return Pair.of(vector, ((ByteBuffer)vector).position());
+        Object vector = ParamUtils.checkFieldValue(fieldType, object);
+        DataType dataType = fieldType.getDataType();
+        switch (dataType) {
+            case FloatVector:
+                return Pair.of(vector, ((List<?>) vector).size() * 4);
+            case BinaryVector:
+                return Pair.of(vector, ((ByteBuffer)vector).limit());
+            case Float16Vector:
+            case BFloat16Vector:
+                return Pair.of(vector, ((ByteBuffer)vector).limit() * 2);
+            case SparseFloatVector:
+                return Pair.of(vector, ((SortedMap<Long, Float>)vector).size() * 12);
+            default:
+                ExceptionUtils.throwUnExpectedException("Unknown vector type");
         }
+        return null;
     }
 
     private Pair<Object, Integer> verifyVarchar(JsonElement object, FieldType fieldType) {

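The verifyVector change above also extends the rough per-row size accounting to the new types. A sketch of the arithmetic, assuming a 512-dimensional schema like the example collection; the numbers are estimates only, mirroring the expressions in the switch above:

    // Sketch: approximate byte estimates per vector field, for dim = 512.
    int dim = 512;
    int floatVecBytes   = dim * 4;           // FloatVector: 4 bytes per element -> 2048
    int binaryVecBytes  = dim / 8;           // BinaryVector: ByteBuffer limit of dim/8 bytes -> 64
    int float16VecBytes = (dim * 2) * 2;     // Float16/BFloat16: buffer limit (dim*2 bytes) doubled -> 2048
    int sparsePairs     = 50;                // illustrative number of non-zero entries
    int sparseVecBytes  = sparsePairs * 12;  // SparseFloatVector: ~12 bytes per index/value pair -> 600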
+ 3 - 0
src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

@@ -45,6 +45,8 @@ public class ParquetUtils {
                             .named(fieldType.getName());
                     break;
                 case BinaryVector:
+                case Float16Vector:
+                case BFloat16Vector:
                     messageTypeBuilder.requiredList()
                             .requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false))
                             .named(fieldType.getName());
@@ -59,6 +61,7 @@ public class ParquetUtils {
                     break;
                 case VarChar:
                 case JSON:
+                case SparseFloatVector: // sparse vector is parsed as JSON format string in the server side
                     messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
                             .named(fieldType.getName());
                     break;

+ 34 - 6
src/main/java/io/milvus/param/ParamUtils.java

@@ -648,9 +648,10 @@ public class ParamUtils {
                 byteStrings.add(bs);
             } else if (vector instanceof SortedMap) {
                 plType = PlaceholderType.SparseFloatVector;
-                SortedMap<Long, Float> map = (SortedMap<Long, Float>) vector;
-                ByteString bs = genSparseFloatBytes(map);
-                byteStrings.add(bs);
+                SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) vector;
+                ByteBuffer buf = encodeSparseFloatVector(sparse);
+                ByteString byteString = ByteString.copyFrom(buf.array());
+                byteStrings.add(byteString);
             } else {
                 String msg = "Search target vector type is illegal." +
                         " Only allow List<Float> for FloatVector," +
@@ -1023,7 +1024,7 @@ public class ParamUtils {
         throw new ParamException("Illegal vector dataType:" + dataType);
     }
 
-    private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
+    public static ByteBuffer encodeSparseFloatVector(SortedMap<Long, Float> sparse) {
         // milvus server requires sparse vector to be transfered in little endian
         ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
         buf.order(ByteOrder.LITTLE_ENDIAN);
@@ -1047,7 +1048,33 @@ public class ParamUtils {
             buf.putFloat(entry.getValue());
         }
 
-        return ByteString.copyFrom(buf.array());
+        return buf;
+    }
+
+    public static SortedMap<Long, Float> decodeSparseFloatVector(ByteBuffer buf) {
+        buf.order(ByteOrder.LITTLE_ENDIAN);
+        SortedMap<Long, Float> sparse = new TreeMap<>();
+        long num = buf.limit()/8; // each uint+float pair is 8 bytes
+        for (long j = 0; j < num; j++) {
+            // here we convert an uint 4-bytes to a long value
+            // milvus server requires sparse vector to be transferred in little endian
+            ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
+            pBuf.order(ByteOrder.LITTLE_ENDIAN);
+            int offset = 8*(int)j;
+            byte[] aa = buf.array();
+            for (int k = offset; k < offset + 4; k++) {
+                pBuf.put(aa[k]); // fill the first 4 bytes with the uint bytes
+            }
+            pBuf.putInt(0); // fill the last 4 bytes to zero
+            pBuf.rewind(); // reset position to head
+            long k = pBuf.getLong(); // this is the long value converted from the uint
+
+            // here we get the float value as normal
+            buf.position(offset+4); // position offsets 4 bytes since they were converted to long
+            float v = buf.getFloat(); // this is the float value
+            sparse.put(k, v);
+        }
+        return sparse;
     }
 
     private static SparseFloatArray genSparseFloatArray(List<?> objects) {
@@ -1060,7 +1087,8 @@ public class ParamUtils {
             }
             SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) object;
             dim = Math.max(dim, sparse.size());
-            ByteString byteString = genSparseFloatBytes(sparse);
+            ByteBuffer buf = encodeSparseFloatVector(sparse);
+            ByteString byteString = ByteString.copyFrom(buf.array());
             builder.addContents(byteString);
         }
 

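genSparseFloatBytes is split here into the public encodeSparseFloatVector/decodeSparseFloatVector pair so the same codec can be reused by FieldDataWrapper (see below). A minimal round-trip sketch; the values are illustrative:

    import io.milvus.param.ParamUtils;

    import java.nio.ByteBuffer;
    import java.util.SortedMap;
    import java.util.TreeMap;

    public class SparseVectorCodecDemo {
        public static void main(String[] args) {
            SortedMap<Long, Float> sparse = new TreeMap<>();
            sparse.put(2L, 0.33f);
            sparse.put(1000L, 0.75f);

            // encode into the little-endian uint32/float32 pair layout the server expects
            ByteBuffer buf = ParamUtils.encodeSparseFloatVector(sparse);

            // decode back into an index -> value map; should equal the original
            SortedMap<Long, Float> decoded = ParamUtils.decodeSparseFloatVector(buf);
            System.out.println(decoded); // {2=0.33, 1000=0.75}
        }
    }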
+ 1 - 22
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -246,28 +246,7 @@ public class FieldDataWrapper {
                 for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
                     ByteString bs = sparseArray.getContents(i);
                     ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
-                    bf.order(ByteOrder.LITTLE_ENDIAN);
-                    SortedMap<Long, Float> sparse = new TreeMap<>();
-                    long num = bf.limit()/8; // each uint+float pair is 8 bytes
-                    for (long j = 0; j < num; j++) {
-                        // here we convert an uint 4-bytes to a long value
-                        // milvus server requires sparse vector to be transfered in little endian
-                        ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
-                        pBuf.order(ByteOrder.LITTLE_ENDIAN);
-                        int offset = 8*(int)j;
-                        byte[] aa = bf.array();
-                        for (int k = offset; k < offset + 4; k++) {
-                            pBuf.put(aa[k]); // fill the first 4 bytes with the unit bytes
-                        }
-                        pBuf.putInt(0); // fill the last 4 bytes to zero
-                        pBuf.rewind(); // reset position to head
-                        long k = pBuf.getLong(); // this is the long value converted from the uint
-
-                        // here we get the float value as normal
-                        bf.position(offset+4); // position offsets 4 bytes since they were converted to long
-                        float v = bf.getFloat(); // this is the float value
-                        sparse.put(k, v);
-                    }
+                    SortedMap<Long, Float> sparse = ParamUtils.decodeSparseFloatVector(bf);
                     packData.add(sparse);
                 }
                 return packData;

+ 7 - 0
src/main/java/io/milvus/response/QueryResultsWrapper.java

@@ -159,6 +159,13 @@ public class QueryResultsWrapper extends RowRecordWrapper {
             return obj;
         }
 
+        /**
+         * Test if a key exists
+         */
+        public boolean contains(String keyName) {
+            return fieldValues.containsKey(keyName);
+        }
+
         /**
          * Constructs a <code>String</code> by {@link QueryResultsWrapper.RowRecord} instance.
          *

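A small usage sketch for the new contains() helper, for example to read the dynamic field only when the server actually returned it; the query() call stands in for the private helper used in BulkWriterExample above, so this is a fragment rather than a standalone program.

    // Sketch: guard optional/dynamic fields with contains() before calling get().
    List<QueryResultsWrapper.RowRecord> rows = query("id in [100, 5000]", Lists.newArrayList("*"));
    for (QueryResultsWrapper.RowRecord record : rows) {
        if (record.contains("dynamic")) {
            System.out.println("dynamic = " + record.get("dynamic"));
        }
    }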
+ 7 - 0
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -317,6 +317,13 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             return obj;
         }
 
+        /**
+         * Test if a key exists
+         */
+        public boolean contains(String keyName) {
+            return fieldValues.containsKey(keyName);
+        }
+
         @Override
         public String toString() {
             List<String> pairs = new ArrayList<>();

File diff suppressed because it is too large
+ 458 - 222
src/test/java/io/milvus/client/MilvusClientDockerTest.java


+ 4 - 4
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -236,14 +236,14 @@ class MilvusClientV2DockerTest {
                 }
                 return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
             }
-            case Int8: {
-                List<Integer> values = new ArrayList<>();
+            case Int8:
+            case Int16: {
+                List<Short> values = new ArrayList<>();
                 for (int i = 0; i < eleCnt; i++) {
-                    values.add(RANDOM.nextInt(256));
+                    values.add((short)RANDOM.nextInt(256));
                 }
                 return GSON_INSTANCE.toJsonTree(values).getAsJsonArray();
             }
-            case Int16:
             case Int32: {
                 List<Integer> values = new ArrayList<>();
                 for (int i = 0; i < eleCnt; i++) {

Some files were not shown because too many files changed in this diff