Browse Source

Support sparse vector (#831)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 year ago
parent
commit
5718086ed0

+ 200 - 0
examples/main/java/io/milvus/BinaryVectorExample.java

@@ -0,0 +1,200 @@
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.grpc.*;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.SearchParam;
+import io.milvus.param.index.CreateIndexParam;
+import io.milvus.response.FieldDataWrapper;
+import io.milvus.response.QueryResultsWrapper;
+import io.milvus.response.SearchResultsWrapper;
+
+import java.nio.ByteBuffer;
+import java.util.*;
+
+public class BinaryVectorExample {
+    private static final String COLLECTION_NAME = "java_sdk_example_sparse";
+    private static final String ID_FIELD = "id";
+    private static final String VECTOR_FIELD = "vector";
+
+    private static final Integer VECTOR_DIM = 512;
+
+    private static List<ByteBuffer> generateVectors(int count) {
+        Random ran = new Random();
+        List<ByteBuffer> vectors = new ArrayList<>();
+        int byteCount = VECTOR_DIM / 8;
+        for (int n = 0; n < count; ++n) {
+            ByteBuffer vector = ByteBuffer.allocate(byteCount);
+            for (int i = 0; i < byteCount; ++i) {
+                vector.put((byte) ran.nextInt(Byte.MAX_VALUE));
+            }
+            vectors.add(vector);
+        }
+        return vectors;
+
+    }
+
+    private static void handleResponseStatus(R<?> r) {
+        if (r.getStatus() != R.Status.Success.getCode()) {
+            throw new RuntimeException(r.getMessage());
+        }
+    }
+
+    public static void main(String[] args) {
+        // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
+        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build());
+
+        // drop the collection if you don't need the collection anymore
+        R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        handleResponseStatus(hasR);
+        if (hasR.getData()) {
+            milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .build());
+        }
+
+        // Define fields
+        List<FieldType> fieldsSchema = Arrays.asList(
+                FieldType.newBuilder()
+                        .withName(ID_FIELD)
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withAutoID(false)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(VECTOR_FIELD)
+                        .withDataType(DataType.BinaryVector)
+                        .withDimension(VECTOR_DIM)
+                        .build()
+        );
+
+        // Create the collection
+        R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .withFieldTypes(fieldsSchema)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection created");
+
+        // Insert entities
+        int rowCount = 10000;
+        List<Long> ids = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            ids.add(i);
+        }
+        List<ByteBuffer> vectors = generateVectors(rowCount);
+
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
+        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFields(fieldsInsert)
+                .build();
+
+        R<MutationResult> insertR = milvusClient.insert(insertParam);
+        handleResponseStatus(insertR);
+
+        // Flush the data to storage for testing purpose
+        // Note that no need to manually call flush interface in practice
+        R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
+                addCollectionName(COLLECTION_NAME).
+                build());
+        handleResponseStatus(flushR);
+        System.out.println("Entities inserted");
+
+        // Specify an index type on the vector field.
+        ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.BIN_IVF_FLAT)
+                .withMetricType(MetricType.HAMMING)
+                .withExtraParam("{\"nlist\":64}")
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Index created");
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection loaded");
+
+        // Pick some vectors from the inserted vectors to search
+        // Ensure the returned top1 item's ID should be equal to target vector's ID
+        for (int i = 0; i < 10; i++) {
+            Random ran = new Random();
+            int k = ran.nextInt(rowCount);
+            ByteBuffer targetVector = vectors.get(k);
+            R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .withMetricType(MetricType.HAMMING)
+                    .withTopK(3)
+                    .withVectors(Collections.singletonList(targetVector))
+                    .withVectorFieldName(VECTOR_FIELD)
+                    .addOutField(VECTOR_FIELD)
+                    .withParams("{\"nprobe\":16}")
+                    .build());
+            handleResponseStatus(searchRet);
+
+            // The search() allows multiple target vectors to search in a batch.
+            // Here we only input one vector to search, get the result of No.0 vector to check
+            SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
+            List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
+            System.out.printf("The result of No.%d target vector:\n", i);
+            for (SearchResultsWrapper.IDScore score : scores) {
+                System.out.printf("ID: %d, Score: %f, Vector: ", score.getLongID(), score.getScore());
+                ByteBuffer vector = (ByteBuffer)score.get(VECTOR_FIELD);
+                vector.rewind();
+                while (vector.hasRemaining()) {
+                    System.out.print(Integer.toBinaryString(vector.get()));
+                }
+                System.out.println();
+            }
+            if (scores.get(0).getLongID() != k) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
+                        scores.get(0).getLongID(), k));
+            }
+        }
+        System.out.println("Search result is correct");
+
+        // Retrieve some data
+        int n = 99;
+        R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr(String.format("id == %d", n))
+                .addOutField(VECTOR_FIELD)
+                .build());
+        handleResponseStatus(queryR);
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
+        FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+        List<?> r = field.getFieldData();
+        if (r.isEmpty()) {
+            throw new RuntimeException("The query result is empty");
+        } else {
+            ByteBuffer vector = (ByteBuffer) r.get(0);
+            if (vector.compareTo(vectors.get(n)) != 0) {
+                throw new RuntimeException("The query result is incorrect");
+            }
+        }
+        System.out.println("Query result is correct");
+
+        // drop the collection if you don't need the collection anymore
+        milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection dropped");
+
+        milvusClient.close();
+    }
+}

+ 26 - 2
examples/main/java/io/milvus/Float16Example.java → examples/main/java/io/milvus/Float16VectorExample.java

@@ -35,7 +35,7 @@ import org.tensorflow.ndarray.buffer.ByteDataBuffer;
 import org.tensorflow.types.*;
 import org.tensorflow.types.*;
 
 
 
 
-public class Float16Example {
+public class Float16VectorExample {
     private static final String COLLECTION_NAME = "java_sdk_example_float16";
     private static final String COLLECTION_NAME = "java_sdk_example_float16";
     private static final String ID_FIELD = "id";
     private static final String ID_FIELD = "id";
     private static final String VECTOR_FIELD = "vector";
     private static final String VECTOR_FIELD = "vector";
@@ -174,9 +174,10 @@ public class Float16Example {
                     .withTopK(3)
                     .withTopK(3)
                     .withVectors(Collections.singletonList(targetVector))
                     .withVectors(Collections.singletonList(targetVector))
                     .withVectorFieldName(VECTOR_FIELD)
                     .withVectorFieldName(VECTOR_FIELD)
+                    .addOutField(VECTOR_FIELD)
                     .withParams("{\"nprobe\":32}")
                     .withParams("{\"nprobe\":32}")
                     .build());
                     .build());
-            handleResponseStatus(ret);
+            handleResponseStatus(searchRet);
 
 
             // The search() allows multiple target vectors to search in a batch.
             // The search() allows multiple target vectors to search in a batch.
             // Here we only input one vector to search, get the result of No.0 vector to check
             // Here we only input one vector to search, get the result of No.0 vector to check
@@ -191,11 +192,34 @@ public class Float16Example {
                         scores.get(0).getLongID(), k));
                         scores.get(0).getLongID(), k));
             }
             }
         }
         }
+        System.out.println("Search result is correct");
+
+        // Retrieve some data
+        int n = 99;
+        R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr(String.format("id == %d", n))
+                .addOutField(VECTOR_FIELD)
+                .build());
+        handleResponseStatus(queryR);
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
+        FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+        List<?> r = field.getFieldData();
+        if (r.isEmpty()) {
+            throw new RuntimeException("The query result is empty");
+        } else {
+            ByteBuffer bf = (ByteBuffer) r.get(0);
+            if (!bf.equals(vectors.get(n))) {
+                throw new RuntimeException("The query result is incorrect");
+            }
+        }
+        System.out.println("Query result is correct");
 
 
         // drop the collection if you don't need the collection anymore
         // drop the collection if you don't need the collection anymore
         milvusClient.dropCollection(DropCollectionParam.newBuilder()
         milvusClient.dropCollection(DropCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
                 .build());
                 .build());
+        System.out.println("Collection dropped");
 
 
         milvusClient.close();
         milvusClient.close();
     }
     }

+ 191 - 0
examples/main/java/io/milvus/SparseVectorExample.java

@@ -0,0 +1,191 @@
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.grpc.*;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.SearchParam;
+import io.milvus.param.index.CreateIndexParam;
+import io.milvus.response.FieldDataWrapper;
+import io.milvus.response.QueryResultsWrapper;
+import io.milvus.response.SearchResultsWrapper;
+
+import java.nio.ByteBuffer;
+import java.util.*;
+
+public class SparseVectorExample {
+    private static final String COLLECTION_NAME = "java_sdk_example_sparse";
+    private static final String ID_FIELD = "id";
+    private static final String VECTOR_FIELD = "vector";
+
+    private static List<SortedMap<Long, Float>> generateVectors(int count) {
+        Random ran = new Random();
+        List<SortedMap<Long, Float>> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+            SortedMap<Long, Float> sparse = new TreeMap<>();
+            int dim = ran.nextInt(10) + 1;
+            for (int i = 0; i < dim; ++i) {
+                sparse.put((long)ran.nextInt(1000000), ran.nextFloat());
+            }
+            vectors.add(sparse);
+        }
+        return vectors;
+
+    }
+
+    private static void handleResponseStatus(R<?> r) {
+        if (r.getStatus() != R.Status.Success.getCode()) {
+            throw new RuntimeException(r.getMessage());
+        }
+    }
+
+    public static void main(String[] args) {
+        // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
+        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build());
+
+        // drop the collection if you don't need the collection anymore
+        R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        handleResponseStatus(hasR);
+        if (hasR.getData()) {
+            milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .build());
+        }
+
+        // Define fields
+        List<FieldType> fieldsSchema = Arrays.asList(
+                FieldType.newBuilder()
+                        .withName(ID_FIELD)
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withAutoID(false)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(VECTOR_FIELD)
+                        .withDataType(DataType.SparseFloatVector)
+                        .build()
+        );
+
+        // Create the collection
+        R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .withFieldTypes(fieldsSchema)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection created");
+
+        // Insert entities
+        int rowCount = 10000;
+        List<Long> ids = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            ids.add(i);
+        }
+        List<SortedMap<Long, Float>> vectors = generateVectors(rowCount);
+
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
+        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFields(fieldsInsert)
+                .build();
+
+        R<MutationResult> insertR = milvusClient.insert(insertParam);
+        handleResponseStatus(insertR);
+
+        // Flush the data to storage for testing purpose
+        // Note that no need to manually call flush interface in practice
+        R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
+                addCollectionName(COLLECTION_NAME).
+                build());
+        handleResponseStatus(flushR);
+        System.out.println("Entities inserted");
+
+        // Specify an index type on the vector field.
+        ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.SPARSE_WAND)
+                .withMetricType(MetricType.IP)
+                .withExtraParam("{\"drop_ratio_build\":0.2}")
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Index created");
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection loaded");
+
+        // Pick some vectors from the inserted vectors to search
+        // Ensure the returned top1 item's ID should be equal to target vector's ID
+        for (int i = 0; i < 10; i++) {
+            Random ran = new Random();
+            int k = ran.nextInt(rowCount);
+            SortedMap<Long, Float> targetVector = vectors.get(k);
+            R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .withMetricType(MetricType.IP)
+                    .withTopK(3)
+                    .withVectors(Collections.singletonList(targetVector))
+                    .withVectorFieldName(VECTOR_FIELD)
+                    .addOutField(VECTOR_FIELD)
+                    .withParams("{\"drop_ratio_search\":0.2}")
+                    .build());
+            handleResponseStatus(searchRet);
+
+            // The search() allows multiple target vectors to search in a batch.
+            // Here we only input one vector to search, get the result of No.0 vector to check
+            SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
+            List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
+            System.out.printf("The result of No.%d target vector:\n", i);
+            for (SearchResultsWrapper.IDScore score : scores) {
+                System.out.println(score);
+            }
+            if (scores.get(0).getLongID() != k) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
+                        scores.get(0).getLongID(), k));
+            }
+        }
+        System.out.println("Search result is correct");
+
+        // Retrieve some data
+        int n = 99;
+        R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr(String.format("id == %d", n))
+                .addOutField(VECTOR_FIELD)
+                .build());
+        handleResponseStatus(queryR);
+        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryR.getData());
+        FieldDataWrapper field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+        List<?> r = field.getFieldData();
+        if (r.isEmpty()) {
+            throw new RuntimeException("The query result is empty");
+        } else {
+            SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) r.get(0);
+            if (!sparse.equals(vectors.get(n))) {
+                throw new RuntimeException("The query result is incorrect");
+            }
+        }
+        System.out.println("Query result is correct");
+
+        // drop the collection if you don't need the collection anymore
+        milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection dropped");
+
+        milvusClient.close();
+    }
+}

+ 9 - 6
src/main/java/io/milvus/param/IndexType.java

@@ -27,7 +27,7 @@ import lombok.Getter;
  */
  */
 public enum IndexType {
 public enum IndexType {
     None(0),
     None(0),
-    //Only supported for float vectors
+    // Only supported for float vectors
     FLAT(1),
     FLAT(1),
     IVF_FLAT(2),
     IVF_FLAT(2),
     IVF_SQ8(3),
     IVF_SQ8(3),
@@ -37,19 +37,22 @@ public enum IndexType {
     AUTOINDEX(11),
     AUTOINDEX(11),
     SCANN(12),
     SCANN(12),
 
 
-    // GPU index
+    // GPU indexes only for float vectors
     GPU_IVF_FLAT(50),
     GPU_IVF_FLAT(50),
     GPU_IVF_PQ(51),
     GPU_IVF_PQ(51),
 
 
-    //Only supported for binary vectors
+    // Only supported for binary vectors
     BIN_FLAT(80),
     BIN_FLAT(80),
     BIN_IVF_FLAT(81),
     BIN_IVF_FLAT(81),
 
 
-    //Scalar field index start from here
-    //Only for varchar type field
+    // Only for varchar type field
     TRIE("Trie", 100),
     TRIE("Trie", 100),
-    //Only for scalar type field
+    // Only for scalar type field
     STL_SORT(200),
     STL_SORT(200),
+
+    // Only for sparse vectors
+    SPARSE_INVERTED_INDEX(300),
+    SPARSE_WAND(301)
     ;
     ;
 
 
     @Getter
     @Getter

+ 88 - 12
src/main/java/io/milvus/param/ParamUtils.java

@@ -46,6 +46,7 @@ public class ParamUtils {
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap");
         return typeErrMsg;
         return typeErrMsg;
     }
     }
 
 
@@ -98,12 +99,11 @@ public class ParamUtils {
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
                     }
                     }
                 }
                 }
+                break;
             }
             }
-            break;
             case BinaryVector:
             case BinaryVector:
             case Float16Vector:
             case Float16Vector:
-            case BFloat16Vector:
-            {
+            case BFloat16Vector: {
                 int dim = fieldSchema.getDimension();
                 int dim = fieldSchema.getDimension();
                 for (int i = 0; i < values.size(); ++i) {
                 for (int i = 0; i < values.size(); ++i) {
                     Object value  = values.get(i);
                     Object value  = values.get(i);
@@ -120,8 +120,30 @@ public class ParamUtils {
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
                     }
                     }
                 }
                 }
+                break;
             }
             }
-            break;
+            case SparseFloatVector:
+                for (Object value : values) {
+                    if (!(value instanceof SortedMap)) {
+                        throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
+                    }
+
+                    // is SortedMap<Long, Float> ?
+                    SortedMap<?, ?> m = (SortedMap<?, ?>)value;
+                    if (m.isEmpty()) { // not allow empty value for sparse vector
+                        String msg = "Not allow empty SortedMap for sparse vector field '%s'";
+                        throw new ParamException(String.format(msg, fieldSchema.getName()));
+                    }
+                    if (!(m.firstKey() instanceof Long)) {
+                        String msg = "The key of SortedMap must be Long for sparse vector field '%s'";
+                        throw new ParamException(String.format(msg, fieldSchema.getName()));
+                    }
+                    if (!(m.get(m.firstKey()) instanceof Float)) {
+                        String msg = "The value of SortedMap must be Float for sparse vector field '%s'";
+                        throw new ParamException(String.format(msg, fieldSchema.getName()));
+                    }
+                }
+                break;
             case Int64:
             case Int64:
                 for (Object value : values) {
                 for (Object value : values) {
                     if (!(value instanceof Long)) {
                     if (!(value instanceof Long)) {
@@ -286,7 +308,8 @@ public class ParamUtils {
             }
             }
             if (isPartitionKeyEnabled) {
             if (isPartitionKeyEnabled) {
                 if (partitionName != null && !partitionName.isEmpty()) {
                 if (partitionName != null && !partitionName.isEmpty()) {
-                    String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
+                    String msg = String.format("Collection %s has partition key, not allow to specify partition name",
+                            requestParam.getCollectionName());
                     throw new ParamException(msg);
                     throw new ParamException(msg);
                 }
                 }
             } else if (partitionName != null) {
             } else if (partitionName != null) {
@@ -314,7 +337,8 @@ public class ParamUtils {
                 for (InsertParam.Field field : fields) {
                 for (InsertParam.Field field : fields) {
                     if (field.getName().equals(fieldType.getName())) {
                     if (field.getName().equals(fieldType.getName())) {
                         if (fieldType.isAutoID()) {
                         if (fieldType.isAutoID()) {
-                            String msg = "The primary key: " + fieldType.getName() + " is auto generated, no need to input.";
+                            String msg = String.format("The primary key: %s is auto generated, no need to input.",
+                                    fieldType.getName());
                             throw new ParamException(msg);
                             throw new ParamException(msg);
                         }
                         }
                         checkFieldData(fieldType, field);
                         checkFieldData(fieldType, field);
@@ -326,8 +350,7 @@ public class ParamUtils {
 
 
                 }
                 }
                 if (!found && !fieldType.isAutoID()) {
                 if (!found && !fieldType.isAutoID()) {
-                    String msg = "The field: " + fieldType.getName() + " is not provided.";
-                    throw new ParamException(msg);
+                    throw new ParamException(String.format("The field: %s is not provided.", fieldType.getName()));
                 }
                 }
             }
             }
 
 
@@ -369,7 +392,7 @@ public class ParamUtils {
                     Object rowFieldData = row.get(fieldName);
                     Object rowFieldData = row.get(fieldName);
                     if (rowFieldData != null) {
                     if (rowFieldData != null) {
                         if (fieldType.isAutoID()) {
                         if (fieldType.isAutoID()) {
-                            String msg = "The primary key: " + fieldName + " is auto generated, no need to input.";
+                            String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
                             throw new ParamException(msg);
                             throw new ParamException(msg);
                         }
                         }
                         checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false);
                         checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false);
@@ -379,7 +402,7 @@ public class ParamUtils {
                     } else {
                     } else {
                         // check if autoId
                         // check if autoId
                         if (!fieldType.isAutoID()) {
                         if (!fieldType.isAutoID()) {
-                            String msg = "The field: " + fieldType.getName() + " is not provided.";
+                            String msg = String.format("The field: %s is not provided.", fieldType.getName());
                             throw new ParamException(msg);
                             throw new ParamException(msg);
                         }
                         }
                     }
                     }
@@ -455,8 +478,16 @@ public class ParamUtils {
                 byte[] array = buf.array();
                 byte[] array = buf.array();
                 ByteString bs = ByteString.copyFrom(array);
                 ByteString bs = ByteString.copyFrom(array);
                 byteStrings.add(bs);
                 byteStrings.add(bs);
+            } else if (vector instanceof SortedMap) {
+                plType = PlaceholderType.SparseFloatVector;
+                SortedMap<Long, Float> map = (SortedMap<Long, Float>) vector;
+                ByteString bs = genSparseFloatBytes(map);
+                byteStrings.add(bs);
             } else {
             } else {
-                String msg = "Search target vector type is illegal(Only allow List<Float> or ByteBuffer)";
+                String msg = "Search target vector type is illegal." +
+                        " Only allow List<Float> for FloatVector," +
+                        " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
+                        " List<SortedMap<Long, Float>> for SparseFloatVector.";
                 throw new ParamException(msg);
                 throw new ParamException(msg);
             }
             }
         }
         }
@@ -623,6 +654,7 @@ public class ParamUtils {
             add(DataType.BinaryVector);
             add(DataType.BinaryVector);
             add(DataType.Float16Vector);
             add(DataType.Float16Vector);
             add(DataType.BFloat16Vector);
             add(DataType.BFloat16Vector);
+            add(DataType.SparseFloatVector);
         }};
         }};
         return vectorDataType.contains(dataType);
         return vectorDataType.contains(dataType);
     }
     }
@@ -631,7 +663,6 @@ public class ParamUtils {
         return genFieldData(fieldType, objects, Boolean.FALSE);
         return genFieldData(fieldType, objects, Boolean.FALSE);
     }
     }
 
 
-    @SuppressWarnings("unchecked")
     private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
     private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
         if (objects == null) {
         if (objects == null) {
             throw new ParamException("Cannot generate FieldData from null object");
             throw new ParamException("Cannot generate FieldData from null object");
@@ -694,11 +725,56 @@ public class ParamUtils {
             } else {
             } else {
                 return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
                 return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
             }
             }
+        } else if  (dataType == DataType.SparseFloatVector) {
+            SparseFloatArray sparseArray = genSparseFloatArray(objects);
+            return VectorField.newBuilder().setDim(sparseArray.getDim()).setSparseFloatVector(sparseArray).build();
         }
         }
 
 
         throw new ParamException("Illegal vector dataType:" + dataType);
         throw new ParamException("Illegal vector dataType:" + dataType);
     }
     }
 
 
+    private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
+        ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
+        buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default
+        for (Map.Entry<Long, Float> entry : sparse.entrySet()) {
+            long k = entry.getKey();
+            if (k < 0 || k >= (long)Math.pow(2.0, 32.0)-1) {
+                throw new ParamException("Sparse vector index must be positive and less than 2^32-1");
+            }
+            // here we construct a binary from the long key
+            ByteBuffer lBuf = ByteBuffer.allocate(Long.BYTES);
+            lBuf.order(ByteOrder.LITTLE_ENDIAN);
+            lBuf.putLong(k);
+            // the server requires a binary of unsigned int, append the first 4 bytes
+            buf.put(lBuf.array(), 0, 4);
+
+            float v = entry.getValue();
+            if (Float.isNaN(v) || Float.isInfinite(v)) {
+                throw new ParamException("Sparse vector value cannot be NaN or Infinite");
+            }
+            buf.putFloat(entry.getValue());
+        }
+
+        return ByteString.copyFrom(buf.array());
+    }
+
+    private static SparseFloatArray genSparseFloatArray(List<?> objects) {
+        int dim = 0; // the real dim is unknown, set the max size as dim
+        SparseFloatArray.Builder builder = SparseFloatArray.newBuilder();
+        // each object must be SortedMap<Long, Float>, which is already validated by checkFieldData()
+        for (Object object : objects) {
+            if (!(object instanceof SortedMap)) {
+                throw new ParamException("SparseFloatVector vector field's value type must be SortedMap");
+            }
+            SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) object;
+            dim = Math.max(dim, sparse.size());
+            ByteString byteString = genSparseFloatBytes(sparse);
+            builder.addContents(byteString);
+        }
+
+        return builder.setDim(dim).build();
+    }
+
     private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
     private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
         if (fieldType.getDataType() == DataType.Array) {
         if (fieldType.getDataType() == DataType.Array) {
             ArrayArray.Builder builder = ArrayArray.newBuilder();
             ArrayArray.Builder builder = ArrayArray.newBuilder();

+ 1 - 0
src/main/java/io/milvus/param/QueryNodeSingleSearch.java

@@ -101,6 +101,7 @@ public class QueryNodeSingleSearch {
          * @param vectors list of target vectors:
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float
          *                if vector type is FloatVector, vectors is List of List Float
          *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer
          *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer
+         *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float]
          * @return <code>Builder</code>
          * @return <code>Builder</code>
          */
          */
         public Builder withVectors(@NonNull List<?> vectors) {
         public Builder withVectors(@NonNull List<?> vectors) {

+ 2 - 1
src/main/java/io/milvus/param/collection/FieldType.java

@@ -271,7 +271,8 @@ public class FieldType {
                 throw new ParamException("String type is not supported, use Varchar instead");
                 throw new ParamException("String type is not supported, use Varchar instead");
             }
             }
 
 
-            if (ParamUtils.isVectorDataType(dataType)) {
+            // SparseVector has no dimension, other vector types must have dimension
+            if (ParamUtils.isVectorDataType(dataType) && dataType != DataType.SparseFloatVector) {
                 if (!typeParams.containsKey(Constant.VECTOR_DIM)) {
                 if (!typeParams.containsKey(Constant.VECTOR_DIM)) {
                     throw new ParamException("Vector field dimension must be specified");
                     throw new ParamException("Vector field dimension must be specified");
                 }
                 }

+ 2 - 0
src/main/java/io/milvus/param/dml/InsertParam.java

@@ -218,7 +218,9 @@ public class InsertParam {
      * If dataType is Varchar, values is List of String;
      * If dataType is Varchar, values is List of String;
      * If dataType is FloatVector, values is List of List Float;
      * If dataType is FloatVector, values is List of List Float;
      * If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer;
      * If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer;
+     * If dataType is SparseFloatVector, values is List of SortedMap[Long, Float];
      * If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String;
      * If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String;
+     * If dataType is JSON, values is List of JSONObject;
      *
      *
      * Note:
      * Note:
      * If dataType is Int8/Int16/Int32, values is List of Integer or Short
      * If dataType is Int8/Int16/Int32, values is List of Integer or Short

+ 15 - 1
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -32,6 +32,7 @@ import lombok.ToString;
 
 
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.List;
+import java.util.SortedMap;
 
 
 /**
 /**
  * Parameters for <code>search</code> interface.
  * Parameters for <code>search</code> interface.
@@ -238,6 +239,7 @@ public class SearchParam {
          * @param vectors list of target vectors:
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float;
          *                if vector type is FloatVector, vectors is List of List Float;
          *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
          *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
+         *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
          * @return <code>Builder</code>
          * @return <code>Builder</code>
          */
          */
         public Builder withVectors(@NonNull List<?> vectors) {
         public Builder withVectors(@NonNull List<?> vectors) {
@@ -310,6 +312,7 @@ public class SearchParam {
 
 
             if (vectors.get(0) instanceof List) {
             if (vectors.get(0) instanceof List) {
                 // float vectors
                 // float vectors
+                // TODO: here only check the first element, potential risk
                 List<?> first = (List<?>) vectors.get(0);
                 List<?> first = (List<?>) vectors.get(0);
                 if (!(first.get(0) instanceof Float)) {
                 if (!(first.get(0) instanceof Float)) {
                     throw new ParamException("Float vector field's value must be Lst<Float>");
                     throw new ParamException("Float vector field's value must be Lst<Float>");
@@ -324,6 +327,7 @@ public class SearchParam {
                 }
                 }
             } else if (vectors.get(0) instanceof ByteBuffer) {
             } else if (vectors.get(0) instanceof ByteBuffer) {
                 // binary vectors
                 // binary vectors
+                // TODO: here only check the first element, potential risk
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
                 int dim = first.position();
                 int dim = first.position();
                 for (int i = 1; i < vectors.size(); ++i) {
                 for (int i = 1; i < vectors.size(); ++i) {
@@ -332,8 +336,18 @@ public class SearchParam {
                         throw new ParamException("Target vector dimension must be equal");
                         throw new ParamException("Target vector dimension must be equal");
                     }
                     }
                 }
                 }
+            } else if (vectors.get(0) instanceof SortedMap) {
+                // sparse vectors, must be SortedMap<Long, Float>
+                // TODO: here only check the first element, potential risk
+                SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);
+
+
             } else {
             } else {
-                throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
+                String msg = "Search target vector type is illegal." +
+                        " Only allow List<Float> for FloatVector," +
+                        " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
+                        " List<SortedMap<Long, Float>> for SparseFloatVector.";
+                throw new ParamException(msg);
             }
             }
 
 
             return new SearchParam(this);
             return new SearchParam(this);

+ 121 - 30
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -3,18 +3,18 @@ package io.milvus.response;
 import com.alibaba.fastjson.JSONObject;
 import com.alibaba.fastjson.JSONObject;
 import com.google.protobuf.ProtocolStringList;
 import com.google.protobuf.ProtocolStringList;
 import io.milvus.exception.ParamException;
 import io.milvus.exception.ParamException;
-import io.milvus.grpc.ArrayArray;
-import io.milvus.grpc.DataType;
-import io.milvus.grpc.FieldData;
+import io.milvus.grpc.*;
 import io.milvus.exception.IllegalResponseException;
 import io.milvus.exception.IllegalResponseException;
 
 
-import io.milvus.grpc.ScalarField;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.ParamUtils;
 import lombok.NonNull;
 import lombok.NonNull;
 
 
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
+import java.util.SortedMap;
+import java.util.TreeMap;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
 
 
 import com.google.protobuf.ByteString;
 import com.google.protobuf.ByteString;
@@ -56,6 +56,27 @@ public class FieldDataWrapper {
         return (int) fieldData.getVectors().getDim();
         return (int) fieldData.getVectors().getDim();
     }
     }
 
 
+    // this method returns bytes size of each vector according to vector type
+    private int checkDim(DataType dt, ByteString data, int dim) {
+        if (dt == DataType.BinaryVector) {
+            if ((data.size()*8) % dim != 0) {
+                String msg = String.format("Returned binary vector data array size %d doesn't match dimension %d",
+                        data.size(), dim);
+                throw new IllegalResponseException(msg);
+            }
+            return dim/8;
+        } else if (dt == DataType.Float16Vector || dt == DataType.BFloat16Vector) {
+            if (data.size() % (dim*2) != 0) {
+                String msg = String.format("Returned float16 vector data array size %d doesn't match dimension %d",
+                        data.size(), dim);
+                throw new IllegalResponseException(msg);
+            }
+            return dim*2;
+        }
+
+        return 0;
+    }
+
     /**
     /**
      * Gets the row count of a field.
      * Gets the row count of a field.
      * * Throws {@link IllegalResponseException} if the field type is illegal.
      * * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -69,19 +90,34 @@ public class FieldDataWrapper {
                 int dim = getDim();
                 int dim = getDim();
                 List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
                 List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
                 if (data.size() % dim != 0) {
                 if (data.size() % dim != 0) {
-                    throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
+                    String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
+                            data.size(), dim);
+                    throw new IllegalResponseException(msg);
                 }
                 }
 
 
                 return data.size()/dim;
                 return data.size()/dim;
             }
             }
             case BinaryVector: {
             case BinaryVector: {
+                // for binary vector, each dimension is one bit, each byte is 8 dim
                 int dim = getDim();
                 int dim = getDim();
                 ByteString data = fieldData.getVectors().getBinaryVector();
                 ByteString data = fieldData.getVectors().getBinaryVector();
-                if ((data.size()*8) % dim != 0) {
-                    throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
-                }
+                int bytePerVec = checkDim(dt, data, dim);
 
 
-                return (data.size()*8)/dim;
+                return data.size()/bytePerVec;
+            }
+            case Float16Vector:
+            case BFloat16Vector: {
+                // for float16 vector, each dimension 2 bytes
+                int dim = getDim();
+                ByteString data = (dt == DataType.Float16Vector) ?
+                        fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
+                int bytePerVec = checkDim(dt, data, dim);
+
+                return data.size()/bytePerVec;
+            }
+            case SparseFloatVector: {
+                // for sparse vector, each content is a vector
+                return fieldData.getVectors().getSparseFloatVector().getContentsCount();
             }
             }
             case Int64:
             case Int64:
                 return fieldData.getScalars().getLongData().getDataCount();
                 return fieldData.getScalars().getLongData().getDataCount();
@@ -109,15 +145,17 @@ public class FieldDataWrapper {
 
 
     /**
     /**
      * Returns the field data according to its type:
      * Returns the field data according to its type:
-     *      float vector field return List of List Float,
-     *      binary vector field return List of ByteBuffer
-     *      int64 field return List of Long
-     *      int32/int16/int8 field return List of Integer
-     *      boolean field return List of Boolean
-     *      float field return List of Float
-     *      double field return List of Double
-     *      varchar field return List of String
-     *      array field return List of List
+     *      FloatVector field returns List of List Float,
+     *      BinaryVector/Float16Vector/BFloat16Vector fields return List of ByteBuffer
+     *      SparseFloatVector field returns List of SortedMap[Long, Float]
+     *      Int64 field returns List of Long
+     *      Int32/Int16/Int8 fields return List of Integer
+     *      Bool field returns List of Boolean
+     *      Float field returns List of Float
+     *      Double field returns List of Double
+     *      Varchar field returns List of String
+     *      Array field returns List of List
+     *      JSON field returns List of String;
      *      etc.
      *      etc.
      *
      *
      * Throws {@link IllegalResponseException} if the field type is illegal.
      * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -131,7 +169,9 @@ public class FieldDataWrapper {
                 int dim = getDim();
                 int dim = getDim();
                 List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
                 List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
                 if (data.size() % dim != 0) {
                 if (data.size() % dim != 0) {
-                    throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
+                    String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
+                            data.size(), dim);
+                    throw new IllegalResponseException(msg);
                 }
                 }
 
 
                 List<List<Float>> packData = new ArrayList<>();
                 List<List<Float>> packData = new ArrayList<>();
@@ -141,16 +181,22 @@ public class FieldDataWrapper {
                 }
                 }
                 return packData;
                 return packData;
             }
             }
-            case BinaryVector: {
+            case BinaryVector:
+            case Float16Vector:
+            case BFloat16Vector: {
                 int dim = getDim();
                 int dim = getDim();
-                ByteString data = fieldData.getVectors().getBinaryVector();
-                if ((data.size()*8) % dim != 0) {
-                    throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
+                ByteString data = null;
+                if (dt == DataType.BinaryVector) {
+                    data = fieldData.getVectors().getBinaryVector();
+                } else if (dt == DataType.Float16Vector) {
+                    data = fieldData.getVectors().getFloat16Vector();
+                } else {
+                    data = fieldData.getVectors().getBfloat16Vector();
                 }
                 }
 
 
-                List<ByteBuffer> packData = new ArrayList<>();
-                int bytePerVec = dim/8;
+                int bytePerVec = checkDim(dt, data, dim);
                 int count = data.size()/bytePerVec;
                 int count = data.size()/bytePerVec;
+                List<ByteBuffer> packData = new ArrayList<>();
                 for (int i = 0; i < count; ++i) {
                 for (int i = 0; i < count; ++i) {
                     ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
                     ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
                     bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
                     bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
@@ -158,6 +204,40 @@ public class FieldDataWrapper {
                 }
                 }
                 return packData;
                 return packData;
             }
             }
+            case SparseFloatVector: {
+                // in Java sdk, each sparse vector is pairs of long+float
+                // in server side, each sparse vector is stored as uint+float (8 bytes)
+                // don't use sparseArray.getDim() because the dim is the max index of each rows
+                SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector();
+                List<SortedMap<Long, Float>> packData = new ArrayList<>();
+                for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
+                    ByteString bs = sparseArray.getContents(i);
+                    ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
+                    bf.order(ByteOrder.LITTLE_ENDIAN);
+                    SortedMap<Long, Float> sparse = new TreeMap<>();
+                    long num = bf.limit()/8; // each uint+float pair is 8 bytes
+                    for (long j = 0; j < num; j++) {
+                        // here we convert an uint 4-bytes to a long value
+                        ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
+                        pBuf.order(ByteOrder.LITTLE_ENDIAN);
+                        int offset = 8*(int)j;
+                        byte[] aa = bf.array();
+                        for (int k = offset; k < offset + 4; k++) {
+                            pBuf.put(aa[k]); // fill the first 4 bytes with the unit bytes
+                        }
+                        pBuf.putInt(0); // fill the last 4 bytes to zero
+                        pBuf.rewind(); // reset position to head
+                        long k = pBuf.getLong(); // this is the long value converted from the uint
+
+                        // here we get the float value as normal
+                        bf.position(offset+4); // position offsets 4 bytes since they were converted to long
+                        float v = bf.getFloat(); // this is the float value
+                        sparse.put(k, v);
+                    }
+                    packData.add(sparse);
+                }
+                return packData;
+            }
             case Array:
             case Array:
                 List<List<?>> array = new ArrayList<>();
                 List<List<?>> array = new ArrayList<>();
                 ArrayArray arrArray = fieldData.getScalars().getArrayData();
                 ArrayArray arrArray = fieldData.getScalars().getArrayData();
@@ -202,7 +282,7 @@ public class FieldDataWrapper {
                 return protoStrList.subList(0, protoStrList.size());
                 return protoStrList.subList(0, protoStrList.size());
             case JSON:
             case JSON:
                 List<ByteString> dataList = scalar.getJsonData().getDataList();
                 List<ByteString> dataList = scalar.getJsonData().getDataList();
-                return dataList.stream().map(ByteString::toByteArray).collect(Collectors.toList());
+                return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
             default:
             default:
                 return new ArrayList<>();
                 return new ArrayList<>();
         }
         }
@@ -249,14 +329,25 @@ public class FieldDataWrapper {
     }
     }
 
 
     public Object valueByIdx(int index) throws ParamException {
     public Object valueByIdx(int index) throws ParamException {
-        if (index < 0 || index >= getFieldData().size()) {
-            throw new ParamException("index out of range");
+        List<?> data = getFieldData();
+        if (index < 0 || index >= data.size()) {
+            throw new ParamException(String.format("Value index %d out of range %d", index, data.size()));
         }
         }
-        return getFieldData().get(index);
+        return data.get(index);
     }
     }
 
 
     private JSONObject parseObjectData(int index) {
     private JSONObject parseObjectData(int index) {
         Object object = valueByIdx(index);
         Object object = valueByIdx(index);
-        return JSONObject.parseObject(new String((byte[])object));
+        return ParseJSONObject(object);
+    }
+
+    public static JSONObject ParseJSONObject(Object object) {
+        if (object instanceof String) {
+            return JSONObject.parseObject((String)object);
+        } else if (object instanceof byte[]) {
+            return JSONObject.parseObject(new String((byte[]) object));
+        } else {
+            throw new IllegalResponseException("Illegal type value for JSON parser");
+        }
     }
     }
 }
 }

+ 1 - 1
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -187,7 +187,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
 
 
                         Object value = wrapper.valueByIdx((int)offset + n);
                         Object value = wrapper.valueByIdx((int)offset + n);
                         if (wrapper.isJsonField()) {
                         if (wrapper.isJsonField()) {
-                            idScores.get(n).put(field.getFieldName(), JSONObject.parseObject(new String((byte[])value)));
+                            idScores.get(n).put(field.getFieldName(), FieldDataWrapper.ParseJSONObject(value));
                         } else {
                         } else {
                             idScores.get(n).put(field.getFieldName(), value);
                             idScores.get(n).put(field.getFieldName(), value);
                         }
                         }

+ 1 - 1
src/main/java/io/milvus/response/basic/RowRecordWrapper.java

@@ -47,7 +47,7 @@ public abstract class RowRecordWrapper {
                     }
                     }
                     Object value = wrapper.valueByIdx((int)index);
                     Object value = wrapper.valueByIdx((int)index);
                     if (wrapper.isJsonField()) {
                     if (wrapper.isJsonField()) {
-                        JSONObject jsonField = JSONObject.parseObject(new String((byte[])value));
+                        JSONObject jsonField = FieldDataWrapper.ParseJSONObject(value);
                         if (wrapper.isDynamicField()) {
                         if (wrapper.isDynamicField()) {
                             for (String key: jsonField.keySet()) {
                             for (String key: jsonField.keySet()) {
                                 record.put(key, jsonField.get(key));
                                 record.put(key, jsonField.get(key));

+ 1 - 0
src/main/java/io/milvus/v2/utils/DataUtils.java

@@ -461,6 +461,7 @@ public class DataUtils {
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
         typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap");
         return typeErrMsg;
         return typeErrMsg;
     }
     }
 }
 }

+ 132 - 1
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -69,7 +69,7 @@ class MilvusClientDockerTest {
     private static final int dimension = 128;
     private static final int dimension = 128;
 
 
     @Container
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.10");
+    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.0-rc.1");
 
 
     @BeforeAll
     @BeforeAll
     public static void setUp() {
     public static void setUp() {
@@ -142,6 +142,21 @@ class MilvusClientDockerTest {
 
 
     }
     }
 
 
+    private List<SortedMap<Long, Float>> generateSparseVectors(int count) {
+        Random ran = new Random();
+        List<SortedMap<Long, Float>> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+            SortedMap<Long, Float> sparse = new TreeMap<>();
+            int dim = ran.nextInt(10) + 1;
+            for (int i = 0; i < dim; ++i) {
+                sparse.put((long)ran.nextInt(1000000), ran.nextFloat());
+            }
+            vectors.add(sparse);
+        }
+        return vectors;
+
+    }
+
     @Test
     @Test
     void testFloatVectors() {
     void testFloatVectors() {
         String randomCollectionName = generator.generate(10);
         String randomCollectionName = generator.generate(10);
@@ -665,6 +680,122 @@ class MilvusClientDockerTest {
         Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
         Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
     }
     }
 
 
+    @Test
+    void testSparseVector() {
+        String randomCollectionName = generator.generate(10);
+
+        // collection schema
+        String field1Name = "field1";
+        String field2Name = "field2";
+        FieldType field1 = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build();
+
+        FieldType field2 = FieldType.newBuilder()
+                .withDataType(DataType.SparseFloatVector)
+                .withName(field2Name)
+                .build();
+
+        // create collection
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .addFieldType(field1)
+                .addFieldType(field2)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        int rowCount = 10000;
+        List<Long> ids = new ArrayList<>();
+        for (int i = 0; i < rowCount; i++) {
+            ids.add((long)i);
+        }
+        List<SortedMap<Long, Float>> vectors = generateSparseVectors(rowCount);
+        List<InsertParam.Field> fields = new ArrayList<>();
+        fields.add(new InsertParam.Field(field1Name, ids));
+        fields.add(new InsertParam.Field(field2Name, vectors));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFields(fields)
+                .build();
+
+        R<MutationResult> insertR = client.insert(insertParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
+
+        // create index
+        CreateIndexParam indexParam = CreateIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFieldName(field2Name)
+                .withIndexType(IndexType.SPARSE_INVERTED_INDEX)
+                .withMetricType(MetricType.IP)
+                .withExtraParam("{\"drop_ratio_build\":0.2}")
+                .build();
+
+        R<RpcStatus> createIndexR = client.createIndex(indexParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        // load collection
+        R<RpcStatus> loadR = client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
+
+        // pick some vectors to search with index
+        int nq = 5;
+        List<Long> targetVectorIDs = new ArrayList<>();
+        List<SortedMap<Long, Float>> targetVectors = new ArrayList<>();
+        Random ran = new Random();
+        int randomIndex = ran.nextInt(rowCount);
+        for (int i = randomIndex; i < randomIndex + nq; ++i) {
+            targetVectorIDs.add(ids.get(i));
+            targetVectors.add(vectors.get(i));
+        }
+
+        System.out.println("Search target IDs:" + targetVectorIDs);
+        System.out.println("Search target vectors:" + targetVectors);
+
+        int topK = 5;
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.IP)
+                .withTopK(topK)
+                .withVectors(targetVectors)
+                .withVectorFieldName(field2Name)
+                .addOutField(field2Name)
+                .withParams("{\"drop_ratio_search\":0.2}")
+                .build();
+
+        R<SearchResults> searchR = client.search(searchParam);
+//        System.out.println(searchR);
+        Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // verify the search result
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        for (int i = 0; i < targetVectors.size(); ++i) {
+            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
+            System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
+            System.out.println(scores);
+            Assertions.assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
+
+            Object v = scores.get(0).get(field2Name);
+            SortedMap<Long, Float> sparse = (SortedMap<Long, Float>)v;
+            Assertions.assertTrue(sparse.equals(targetVectors.get(i)));
+        }
+
+        // drop collection
+        DropCollectionParam dropParam = DropCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build();
+
+        R<RpcStatus> dropR = client.dropCollection(dropParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue());
+    }
+
     @Test
     @Test
     void testAsyncMethods() {
     void testAsyncMethods() {
         String randomCollectionName = generator.generate(10);
         String randomCollectionName = generator.generate(10);