Browse Source

Explicitly specify vector type when search (#848)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 year ago
parent
commit
b80814b89c

+ 1 - 1
examples/main/java/io/milvus/BinaryVectorExample.java

@@ -140,7 +140,7 @@ public class BinaryVectorExample {
                     .withCollectionName(COLLECTION_NAME)
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.HAMMING)
                     .withMetricType(MetricType.HAMMING)
                     .withTopK(3)
                     .withTopK(3)
-                    .withVectors(Collections.singletonList(targetVector))
+                    .withBinaryVectors(Collections.singletonList(targetVector))
                     .withVectorFieldName(VECTOR_FIELD)
                     .withVectorFieldName(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
                     .withParams("{\"nprobe\":16}")
                     .withParams("{\"nprobe\":16}")

+ 8 - 4
examples/main/java/io/milvus/Float16VectorExample.java

@@ -137,15 +137,19 @@ public class Float16VectorExample {
             Random ran = new Random();
             Random ran = new Random();
             int k = ran.nextInt(rowCount);
             int k = ran.nextInt(rowCount);
             ByteBuffer targetVector = vectors.get(k);
             ByteBuffer targetVector = vectors.get(k);
-            R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
+            SearchParam.Builder builder = SearchParam.newBuilder()
                     .withCollectionName(COLLECTION_NAME)
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.L2)
                     .withMetricType(MetricType.L2)
                     .withTopK(3)
                     .withTopK(3)
-                    .withVectors(Collections.singletonList(targetVector))
                     .withVectorFieldName(VECTOR_FIELD)
                     .withVectorFieldName(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
-                    .withParams("{\"nprobe\":32}")
-                    .build());
+                    .withParams("{\"nprobe\":32}");
+            if (bfloat16) {
+                builder.withBFloat16Vectors(Collections.singletonList(targetVector));
+            } else {
+                builder.withFloat16Vectors(Collections.singletonList(targetVector));
+            }
+            R<SearchResults> searchRet = milvusClient.search(builder.build());
             CommonUtils.handleResponseStatus(searchRet);
             CommonUtils.handleResponseStatus(searchRet);
 
 
             // The search() allows multiple target vectors to search in a batch.
             // The search() allows multiple target vectors to search in a batch.

+ 1 - 1
examples/main/java/io/milvus/GeneralExample.java

@@ -331,7 +331,7 @@ public class GeneralExample {
                 .withMetricType(MetricType.L2)
                 .withMetricType(MetricType.L2)
                 .withOutFields(outFields)
                 .withOutFields(outFields)
                 .withTopK(SEARCH_K)
                 .withTopK(SEARCH_K)
-                .withVectors(vectors)
+                .withFloatVectors(vectors)
                 .withVectorFieldName(VECTOR_FIELD)
                 .withVectorFieldName(VECTOR_FIELD)
                 .withExpr(expr)
                 .withExpr(expr)
                 .withParams(SEARCH_PARAM)
                 .withParams(SEARCH_PARAM)

+ 6 - 6
examples/main/java/io/milvus/HybridSearchExample.java

@@ -180,7 +180,7 @@ public class HybridSearchExample {
                 .build());
                 .build());
         CommonUtils.handleResponseStatus(resp);
         CommonUtils.handleResponseStatus(resp);
 
 
-        System.out.printf("%d entities inserted by rows", rowCount);
+        System.out.printf("%d entities inserted by rows\n", rowCount);
 
 
         // Insert entities by columns
         // Insert entities by columns
         List<Long> ids = new ArrayList<>();
         List<Long> ids = new ArrayList<>();
@@ -205,7 +205,7 @@ public class HybridSearchExample {
                 .build());
                 .build());
         CommonUtils.handleResponseStatus(resp);
         CommonUtils.handleResponseStatus(resp);
 
 
-        System.out.printf("%d entities inserted by columns", rowCount);
+        System.out.printf("%d entities inserted by columns\n", rowCount);
 
 
         milvusClient.close();
         milvusClient.close();
     }
     }
@@ -231,7 +231,7 @@ public class HybridSearchExample {
         // Note that only allow one vector for each sub request
         // Note that only allow one vector for each sub request
         AnnSearchParam req1 = AnnSearchParam.newBuilder()
         AnnSearchParam req1 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(FLOAT_VECTOR_FIELD)
                 .withVectorFieldName(FLOAT_VECTOR_FIELD)
-                .withVectors(CommonUtils.generateFloatVectors(FLOAT_VECTOR_DIM, 1))
+                .withFloatVectors(CommonUtils.generateFloatVectors(FLOAT_VECTOR_DIM, 1))
                 .withMetricType(FLOAT_VECTOR_METRIC)
                 .withMetricType(FLOAT_VECTOR_METRIC)
                 .withParams("{\"nprobe\": 32}")
                 .withParams("{\"nprobe\": 32}")
                 .withTopK(10)
                 .withTopK(10)
@@ -239,14 +239,14 @@ public class HybridSearchExample {
 
 
         AnnSearchParam req2 = AnnSearchParam.newBuilder()
         AnnSearchParam req2 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(BINARY_VECTOR_FIELD)
                 .withVectorFieldName(BINARY_VECTOR_FIELD)
-                .withVectors(CommonUtils.generateBinaryVectors(BINARY_VECTOR_DIM, 1))
+                .withBinaryVectors(CommonUtils.generateBinaryVectors(BINARY_VECTOR_DIM, 1))
                 .withMetricType(BINARY_VECTOR_METRIC)
                 .withMetricType(BINARY_VECTOR_METRIC)
                 .withTopK(15)
                 .withTopK(15)
                 .build();
                 .build();
 
 
 //        AnnSearchParam req3 = AnnSearchParam.newBuilder()
 //        AnnSearchParam req3 = AnnSearchParam.newBuilder()
 //                .withVectorFieldName(FLOAT16_VECTOR_FIELD)
 //                .withVectorFieldName(FLOAT16_VECTOR_FIELD)
-//                .withVectors(CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, 1, false))
+//                .withFloat16Vectors(CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, 1, false))
 //                .withMetricType(FLOAT16_VECTOR_METRIC)
 //                .withMetricType(FLOAT16_VECTOR_METRIC)
 //                .withParams("{\"es\":200}")
 //                .withParams("{\"es\":200}")
 //                .withTopK(20)
 //                .withTopK(20)
@@ -254,7 +254,7 @@ public class HybridSearchExample {
 
 
         AnnSearchParam req4 = AnnSearchParam.newBuilder()
         AnnSearchParam req4 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(SPARSE_VECTOR_FIELD)
                 .withVectorFieldName(SPARSE_VECTOR_FIELD)
-                .withVectors(CommonUtils.generateSparseVectors(1))
+                .withSparseFloatVectors(CommonUtils.generateSparseVectors(1))
                 .withMetricType(SPARSE_VECTOR_METRIC)
                 .withMetricType(SPARSE_VECTOR_METRIC)
                 .withParams("{\"drop_ratio_search\":0.2}")
                 .withParams("{\"drop_ratio_search\":0.2}")
                 .withTopK(20)
                 .withTopK(20)

+ 1 - 1
examples/main/java/io/milvus/SimpleExample.java

@@ -138,7 +138,7 @@ public class SimpleExample {
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
                 .withMetricType(MetricType.L2)
                 .withMetricType(MetricType.L2)
                 .withTopK(5)
                 .withTopK(5)
-                .withVectors(Arrays.asList(vector))
+                .withFloatVectors(Arrays.asList(vector))
                 .withVectorFieldName(VECTOR_FIELD)
                 .withVectorFieldName(VECTOR_FIELD)
                 .withParams("{}")
                 .withParams("{}")
                 .addOutField(TITLE_FIELD)
                 .addOutField(TITLE_FIELD)

+ 1 - 1
examples/main/java/io/milvus/SparseVectorExample.java

@@ -137,7 +137,7 @@ public class SparseVectorExample {
                     .withCollectionName(COLLECTION_NAME)
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.IP)
                     .withMetricType(MetricType.IP)
                     .withTopK(3)
                     .withTopK(3)
-                    .withVectors(Collections.singletonList(targetVector))
+                    .withSparseFloatVectors(Collections.singletonList(targetVector))
                     .withVectorFieldName(VECTOR_FIELD)
                     .withVectorFieldName(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
                     .addOutField(VECTOR_FIELD)
                     .withParams("{\"drop_ratio_search\":0.2}")
                     .withParams("{\"drop_ratio_search\":0.2}")

+ 8 - 3
src/main/java/io/milvus/param/ParamUtils.java

@@ -462,7 +462,7 @@ public class ParamUtils {
     }
     }
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
-    private static ByteString convertPlaceholder(List<?> vectors) throws ParamException {
+    private static ByteString convertPlaceholder(List<?> vectors, PlaceholderType placeType) throws ParamException {
         PlaceholderType plType = PlaceholderType.None;
         PlaceholderType plType = PlaceholderType.None;
         List<ByteString> byteStrings = new ArrayList<>();
         List<ByteString> byteStrings = new ArrayList<>();
         for (Object vector : vectors) {
         for (Object vector : vectors) {
@@ -496,6 +496,11 @@ public class ParamUtils {
             }
             }
         }
         }
 
 
+        // force specify PlaceholderType
+        if (placeType != PlaceholderType.None) {
+            plType = placeType;
+        }
+
         PlaceholderValue.Builder pldBuilder = PlaceholderValue.newBuilder()
         PlaceholderValue.Builder pldBuilder = PlaceholderValue.newBuilder()
                 .setTag(Constant.VECTOR_TAG)
                 .setTag(Constant.VECTOR_TAG)
                 .setType(plType);
                 .setType(plType);
@@ -522,7 +527,7 @@ public class ParamUtils {
         }
         }
 
 
         // prepare target vectors
         // prepare target vectors
-        ByteString byteStr = convertPlaceholder(requestParam.getVectors());
+        ByteString byteStr = convertPlaceholder(requestParam.getVectors(), requestParam.getPlType());
         builder.setPlaceholderGroup(byteStr);
         builder.setPlaceholderGroup(byteStr);
         builder.setNq(requestParam.getNQ());
         builder.setNq(requestParam.getNQ());
 
 
@@ -611,7 +616,7 @@ public class ParamUtils {
     public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchParam annSearchParam,
     public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchParam annSearchParam,
                                                       ConsistencyLevelEnum consistencyLevel) {
                                                       ConsistencyLevelEnum consistencyLevel) {
         SearchRequest.Builder builder = SearchRequest.newBuilder();
         SearchRequest.Builder builder = SearchRequest.newBuilder();
-        ByteString byteStr = convertPlaceholder(annSearchParam.getVectors());
+        ByteString byteStr = convertPlaceholder(annSearchParam.getVectors(), annSearchParam.getPlType());
         builder.setPlaceholderGroup(byteStr);
         builder.setPlaceholderGroup(byteStr);
         builder.setNq(annSearchParam.getNQ());
         builder.setNq(annSearchParam.getNQ());
 
 

+ 67 - 52
src/main/java/io/milvus/param/dml/AnnSearchParam.java

@@ -20,6 +20,7 @@
 package io.milvus.param.dml;
 package io.milvus.param.dml;
 
 
 import io.milvus.exception.ParamException;
 import io.milvus.exception.ParamException;
+import io.milvus.grpc.PlaceholderType;
 import io.milvus.param.MetricType;
 import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.ParamUtils;
 
 
@@ -45,6 +46,7 @@ public class AnnSearchParam {
     private final List<?> vectors;
     private final List<?> vectors;
     private final Long NQ;
     private final Long NQ;
     private final String params;
     private final String params;
+    private final PlaceholderType plType;
 
 
     private AnnSearchParam(@NonNull Builder builder) {
     private AnnSearchParam(@NonNull Builder builder) {
         this.metricType = builder.metricType.name();
         this.metricType = builder.metricType.name();
@@ -54,6 +56,7 @@ public class AnnSearchParam {
         this.vectors = builder.vectors;
         this.vectors = builder.vectors;
         this.NQ = builder.NQ;
         this.NQ = builder.NQ;
         this.params = builder.params;
         this.params = builder.params;
+        this.plType = builder.plType;
     }
     }
 
 
     public static Builder newBuilder() {
     public static Builder newBuilder() {
@@ -72,6 +75,11 @@ public class AnnSearchParam {
         private Long NQ;
         private Long NQ;
         private String params = "{}";
         private String params = "{}";
 
 
+        // plType is used to distinct vector type
+        // for Float16Vector/BFloat16Vector and BinaryVector, user input ByteBuffer
+        // the server cannot distinct a ByteBuffer is a BinarVector or a Float16Vector
+        private PlaceholderType plType = PlaceholderType.None;
+
         Builder() {
         Builder() {
         }
         }
 
 
@@ -121,18 +129,67 @@ public class AnnSearchParam {
         }
         }
 
 
         /**
         /**
-         * Sets the target vectors.
-         * Note: currently, only support one vector for search.
+         * Sets the target vectors to search on FloatVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloatVectors(@NonNull List<List<Float>> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.FloatVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BinaryVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBinaryVectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.BinaryVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on Float16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.Float16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BFloat16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.BFloat16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on SparseFloatVector field.
          *
          *
-         * @param vectors list of target vectors:
-         *                if vector type is FloatVector, vectors is List of List Float;
-         *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
-         *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
+         * @param vectors target vectors to search
          * @return <code>Builder</code>
          * @return <code>Builder</code>
          */
          */
-        public Builder withVectors(@NonNull List<?> vectors) {
+        public Builder withSparseFloatVectors(@NonNull List<SortedMap<Long, Float>> vectors) {
             this.vectors = vectors;
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
             this.NQ = (long) vectors.size();
+            plType = PlaceholderType.SparseFloatVector;
             return this;
             return this;
         }
         }
 
 
@@ -164,53 +221,11 @@ public class AnnSearchParam {
                 throw new ParamException("TopK value is illegal");
                 throw new ParamException("TopK value is illegal");
             }
             }
 
 
-            if (vectors == null || vectors.isEmpty()) {
-                throw new ParamException("Target vectors can not be empty");
+            if (vectors.isEmpty()) {
+                throw new ParamException("At lease a vector is required for AnnSearchParam");
             }
             }
 
 
-            if (vectors.size() != 1) {
-                throw new ParamException("Only support one vector for each AnnSearchParam");
-            }
-
-            if (vectors.get(0) instanceof List) {
-                // float vectors
-                // TODO: here only check the first element, potential risk
-                List<?> first = (List<?>) vectors.get(0);
-                if (!(first.get(0) instanceof Float)) {
-                    throw new ParamException("Float vector field's value must be Lst<Float>");
-                }
-
-                int dim = first.size();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    List<?> temp = (List<?>) vectors.get(i);
-                    if (dim != temp.size()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
-                }
-            } else if (vectors.get(0) instanceof ByteBuffer) {
-                // binary vectors
-                // TODO: here only check the first element, potential risk
-                ByteBuffer first = (ByteBuffer) vectors.get(0);
-                int dim = first.position();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    ByteBuffer temp = (ByteBuffer) vectors.get(i);
-                    if (dim != temp.position()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
-                }
-            } else if (vectors.get(0) instanceof SortedMap) {
-                // sparse vectors, must be SortedMap<Long, Float>
-                // TODO: here only check the first element, potential risk
-                SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);
-
-
-            } else {
-                String msg = "Search target vector type is illegal." +
-                        " Only allow List<Float> for FloatVector," +
-                        " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
-                        " List<SortedMap<Long, Float>> for SparseFloatVector.";
-                throw new ParamException(msg);
-            }
+            SearchParam.verifyVectors(vectors);
 
 
             return new AnnSearchParam(this);
             return new AnnSearchParam(this);
         }
         }

+ 9 - 0
src/main/java/io/milvus/param/dml/HybridSearchParam.java

@@ -225,6 +225,15 @@ public class HybridSearchParam {
                 throw new ParamException("At least a search request is required");
                 throw new ParamException("At least a search request is required");
             }
             }
 
 
+            int vectorSize = 0;
+            for (AnnSearchParam req : searchRequests) {
+                if (vectorSize == 0) {
+                    vectorSize = req.getVectors().size();
+                } else if (vectorSize != req.getVectors().size()) {
+                    throw new ParamException("Vector number of each AnnSearchParam must be equal");
+                }
+            }
+
             if (topK <= 0) {
             if (topK <= 0) {
                 throw new ParamException("TopK value is illegal");
                 throw new ParamException("TopK value is illegal");
             }
             }

+ 126 - 41
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -22,6 +22,7 @@ package io.milvus.param.dml;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Lists;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.exception.ParamException;
 import io.milvus.exception.ParamException;
+import io.milvus.grpc.PlaceholderType;
 import io.milvus.param.Constant;
 import io.milvus.param.Constant;
 import io.milvus.param.MetricType;
 import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.ParamUtils;
@@ -58,6 +59,7 @@ public class SearchParam {
     private final ConsistencyLevelEnum consistencyLevel;
     private final ConsistencyLevelEnum consistencyLevel;
     private final boolean ignoreGrowing;
     private final boolean ignoreGrowing;
     private final String groupByFieldName;
     private final String groupByFieldName;
+    private final PlaceholderType plType;
 
 
     private SearchParam(@NonNull Builder builder) {
     private SearchParam(@NonNull Builder builder) {
         this.databaseName = builder.databaseName;
         this.databaseName = builder.databaseName;
@@ -78,6 +80,7 @@ public class SearchParam {
         this.consistencyLevel = builder.consistencyLevel;
         this.consistencyLevel = builder.consistencyLevel;
         this.ignoreGrowing = builder.ignoreGrowing;
         this.ignoreGrowing = builder.ignoreGrowing;
         this.groupByFieldName = builder.groupByFieldName;
         this.groupByFieldName = builder.groupByFieldName;
+        this.plType = builder.plType;
     }
     }
 
 
     public static Builder newBuilder() {
     public static Builder newBuilder() {
@@ -107,6 +110,11 @@ public class SearchParam {
         private Boolean ignoreGrowing = Boolean.FALSE;
         private Boolean ignoreGrowing = Boolean.FALSE;
         private String groupByFieldName;
         private String groupByFieldName;
 
 
+        // plType is used to distinct vector type
+        // for Float16Vector/BFloat16Vector and BinaryVector, user input ByteBuffer
+        // the server cannot distinct a ByteBuffer is a BinarVector or a Float16Vector
+        private PlaceholderType plType = PlaceholderType.None;
+
         Builder() {
         Builder() {
         }
         }
 
 
@@ -238,6 +246,10 @@ public class SearchParam {
 
 
         /**
         /**
          * Sets the target vectors.
          * Sets the target vectors.
+         * Note: Deprecated in v2.4.0, for the reason that the server cannot know a ByteBuffer
+         *       is a BinarVector or Float16Vector/BFloat16Vector.
+         *       Replaced by withFloatVectors/withBinaryVectors/withFloat16Vectors/withBFloat16Vectors.
+         *       It still works for FloatVector/BinarVector/SparseVector, don't use it for Float16Vector/BFloat16Vector.
          *
          *
          * @param vectors list of target vectors:
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float;
          *                if vector type is FloatVector, vectors is List of List Float;
@@ -245,12 +257,78 @@ public class SearchParam {
          *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
          *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
          * @return <code>Builder</code>
          * @return <code>Builder</code>
          */
          */
+        @Deprecated
         public Builder withVectors(@NonNull List<?> vectors) {
         public Builder withVectors(@NonNull List<?> vectors) {
             this.vectors = vectors;
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
             this.NQ = (long) vectors.size();
             return this;
             return this;
         }
         }
 
 
+        /**
+         * Sets the target vectors to search on FloatVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloatVectors(@NonNull List<List<Float>> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.FloatVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BinaryVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBinaryVectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.BinaryVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on Float16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.Float16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BFloat16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.BFloat16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on SparseFloatVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withSparseFloatVectors(@NonNull List<SortedMap<Long, Float>> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            plType = PlaceholderType.SparseFloatVector;
+            return this;
+        }
+
         /**
         /**
          * Specifies the decimal place of the returned results.
          * Specifies the decimal place of the returned results.
          *
          *
@@ -320,52 +398,59 @@ public class SearchParam {
                 throw new ParamException("The guarantee timestamp must be greater than 0");
                 throw new ParamException("The guarantee timestamp must be greater than 0");
             }
             }
 
 
-            if (vectors == null || vectors.isEmpty()) {
-                throw new ParamException("Target vectors can not be empty");
-            }
+            SearchParam.verifyVectors(vectors);
 
 
-            if (vectors.get(0) instanceof List) {
-                // float vectors
-                // TODO: here only check the first element, potential risk
-                List<?> first = (List<?>) vectors.get(0);
-                if (!(first.get(0) instanceof Float)) {
-                    throw new ParamException("Float vector field's value must be Lst<Float>");
-                }
+            return new SearchParam(this);
+        }
+    }
+
+    public static void verifyVectors(List<?> vectors) {
+        if (vectors == null || vectors.isEmpty()) {
+            throw new ParamException("Target vectors can not be empty");
+        }
 
 
-                int dim = first.size();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    List<?> temp = (List<?>) vectors.get(i);
-                    if (dim != temp.size()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
+        if (vectors.get(0) instanceof List) {
+            // FloatVector
+            // TODO: here only check the first element, potential risk
+            List<?> first = (List<?>) vectors.get(0);
+            if (!(first.get(0) instanceof Float)) {
+                throw new ParamException("Float vector field's value must be Lst<Float>");
+            }
+
+            int dim = first.size();
+            for (int i = 1; i < vectors.size(); ++i) {
+                List<?> temp = (List<?>) vectors.get(i);
+                if (dim != temp.size()) {
+                    throw new ParamException("Target vector dimension must be equal");
                 }
                 }
-            } else if (vectors.get(0) instanceof ByteBuffer) {
-                // binary vectors
-                // TODO: here only check the first element, potential risk
-                ByteBuffer first = (ByteBuffer) vectors.get(0);
-                int dim = first.position();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    ByteBuffer temp = (ByteBuffer) vectors.get(i);
-                    if (dim != temp.position()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
+            }
+        } else if (vectors.get(0) instanceof ByteBuffer) {
+            // BinaryVector/Float16Vector/BFloatVector
+            // TODO: here only check the first element, potential risk
+            ByteBuffer first = (ByteBuffer) vectors.get(0);
+            int len = first.position();
+            for (int i = 1; i < vectors.size(); ++i) {
+                ByteBuffer temp = (ByteBuffer) vectors.get(i);
+                if (len != temp.position()) {
+                    throw new ParamException("Target vector dimension must be equal");
                 }
                 }
-            } else if (vectors.get(0) instanceof SortedMap) {
-                // sparse vectors, must be SortedMap<Long, Float>
-                // TODO: here only check the first element, potential risk
-                SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);
-
-
-            } else {
-                String msg = "Search target vector type is illegal." +
-                        " Only allow List<Float> for FloatVector," +
-                        " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
-                        " List<SortedMap<Long, Float>> for SparseFloatVector.";
-                throw new ParamException(msg);
             }
             }
-
-            return new SearchParam(this);
+        } else if (vectors.get(0) instanceof SortedMap) {
+            // SparseFloatVector
+            // TODO: here only check the first element, potential risk
+            SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);
+            if (!(map.firstKey() instanceof Long)) {
+                throw new ParamException("key type of SparseFloatVector must be Long");
+            }
+            if (!(map.get(map.firstKey()) instanceof Float)) {
+                throw new ParamException("Value type of SparseFloatVector must be Float");
+            }
+        } else {
+            String msg = "Search target vector type is illegal." +
+                    " Only allow List<Float> for FloatVector," +
+                    " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
+                    " List<SortedMap<Long, Float>> for SparseFloatVector.";
+            throw new ParamException(msg);
         }
         }
     }
     }
-
 }
 }

+ 1 - 1
src/main/java/io/milvus/param/index/AlterIndexParam.java

@@ -30,7 +30,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 
 
 /**
 /**
- * Parameters for <code>alterCollection</code> interface.
+ * Parameters for <code>alterIndex</code> interface.
  */
  */
 @Getter
 @Getter
 @ToString
 @ToString

+ 4 - 4
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -69,7 +69,7 @@ class MilvusClientDockerTest {
     private static final int dimension = 128;
     private static final int dimension = 128;
 
 
     @Container
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.0-rc.1");
+    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:2.4-20240411-35f39593-amd64");
 
 
     @BeforeAll
     @BeforeAll
     public static void setUp() {
     public static void setUp() {
@@ -900,7 +900,7 @@ class MilvusClientDockerTest {
         // search on multiple vector fields
         // search on multiple vector fields
         AnnSearchParam param1 = AnnSearchParam.newBuilder()
         AnnSearchParam param1 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(floatVectorField)
                 .withVectorFieldName(floatVectorField)
-                .withVectors(generateFloatVectors(1))
+                .withFloatVectors(generateFloatVectors(1))
                 .withMetricType(MetricType.COSINE)
                 .withMetricType(MetricType.COSINE)
                 .withParams("{\"nprobe\": 32}")
                 .withParams("{\"nprobe\": 32}")
                 .withTopK(10)
                 .withTopK(10)
@@ -908,7 +908,7 @@ class MilvusClientDockerTest {
 
 
         AnnSearchParam param2 = AnnSearchParam.newBuilder()
         AnnSearchParam param2 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(binaryVectorField)
                 .withVectorFieldName(binaryVectorField)
-                .withVectors(generateBinaryVectors(1))
+                .withBinaryVectors(generateBinaryVectors(1))
                 .withMetricType(MetricType.HAMMING)
                 .withMetricType(MetricType.HAMMING)
                 .withParams("{}")
                 .withParams("{}")
                 .withTopK(5)
                 .withTopK(5)
@@ -916,7 +916,7 @@ class MilvusClientDockerTest {
 
 
         AnnSearchParam param3 = AnnSearchParam.newBuilder()
         AnnSearchParam param3 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(sparseVectorField)
                 .withVectorFieldName(sparseVectorField)
-                .withVectors(generateSparseVectors(1))
+                .withSparseFloatVectors(generateSparseVectors(1))
                 .withMetricType(MetricType.IP)
                 .withMetricType(MetricType.IP)
                 .withParams("{\"drop_ratio_search\":0.2}")
                 .withParams("{\"drop_ratio_search\":0.2}")
                 .withTopK(7)
                 .withTopK(7)

+ 2 - 2
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -2194,7 +2194,7 @@ class MilvusServiceClientTest {
                 .withVectorFieldName("field1")
                 .withVectorFieldName("field1")
                 .withMetricType(MetricType.IP)
                 .withMetricType(MetricType.IP)
                 .withTopK(5)
                 .withTopK(5)
-                .withVectors(vectors)
+                .withFloatVectors(vectors)
                 .withExpr("dummy")
                 .withExpr("dummy")
                 .build();
                 .build();
 
 
@@ -2208,7 +2208,7 @@ class MilvusServiceClientTest {
                 .withVectorFieldName("field2")
                 .withVectorFieldName("field2")
                 .withMetricType(MetricType.HAMMING)
                 .withMetricType(MetricType.HAMMING)
                 .withTopK(5)
                 .withTopK(5)
-                .withVectors(bVectors)
+                .withBinaryVectors(bVectors)
                 .withExpr("dummy")
                 .withExpr("dummy")
                 .build();
                 .build();