Browse Source

searchIterator support multi vector type (#857)

Signed-off-by: lentitude2tk <xushuang.hu@zilliz.com>
xushuang.hu 1 year ago
parent
commit
8a3bbb92af

+ 34 - 5
src/main/java/io/milvus/orm/iterator/SearchIterator.java

@@ -1,6 +1,7 @@
 package io.milvus.orm.iterator;
 
 import com.amazonaws.util.CollectionUtils;
+import com.amazonaws.util.StringUtils;
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.google.common.collect.Lists;
 import io.milvus.common.utils.ExceptionUtils;
@@ -21,10 +22,12 @@ import io.milvus.v2.utils.RpcUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.nio.ByteBuffer;
 import java.text.DecimalFormat;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.SortedMap;
 
 import static io.milvus.param.Constant.DEFAULT_SEARCH_EXTENSION_RATE;
 import static io.milvus.param.Constant.EF;
@@ -215,7 +218,7 @@ public class SearchIterator {
     }
 
     private SearchResultsWrapper executeNextSearch(Map<String, Object> params, String nextExpr, boolean toExtendBatch) {
-        SearchParam searchParam = SearchParam.newBuilder()
+        SearchParam.Builder searchParamBuilder = SearchParam.newBuilder()
                 .withDatabaseName(searchIteratorParam.getDatabaseName())
                 .withCollectionName(searchIteratorParam.getCollectionName())
                 .withPartitionNames(searchIteratorParam.getPartitionNames())
@@ -224,13 +227,16 @@ public class SearchIterator {
                 .withTopK(extendBatchSize(batchSize, toExtendBatch, params))
                 .withExpr(nextExpr)
                 .withOutFields(searchIteratorParam.getOutFields())
-                .withVectors(searchIteratorParam.getVectors())
                 .withRoundDecimal(searchIteratorParam.getRoundDecimal())
                 .withParams(JacksonUtils.toJsonString(params))
-                .withIgnoreGrowing(searchIteratorParam.isIgnoreGrowing())
-                .build();
+                .withIgnoreGrowing(searchIteratorParam.isIgnoreGrowing());
 
-        SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParam);
+        if (!StringUtils.isNullOrEmpty(searchIteratorParam.getGroupByFieldName())) {
+            searchParamBuilder.withGroupByFieldName(searchIteratorParam.getGroupByFieldName());
+        }
+        fillVectorsByPlType(searchParamBuilder);
+
+        SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParamBuilder.build());
         SearchResults response = blockingStub.search(searchRequest);
 
         String title = String.format("SearchRequest collectionName:%s", searchIteratorParam.getCollectionName());
@@ -239,6 +245,29 @@ public class SearchIterator {
         return new SearchResultsWrapper(response.getResults());
     }
 
+    private void fillVectorsByPlType(SearchParam.Builder searchParamBuilder) {
+        switch (searchIteratorParam.getPlType()) {
+            case FloatVector:
+                searchParamBuilder.withFloatVectors((List<List<Float>>) searchIteratorParam.getVectors());
+                break;
+            case BinaryVector:
+                searchParamBuilder.withBinaryVectors((List<ByteBuffer>) searchIteratorParam.getVectors());
+                break;
+            case Float16Vector:
+                searchParamBuilder.withFloat16Vectors((List<ByteBuffer>) searchIteratorParam.getVectors());
+                break;
+            case BFloat16Vector:
+                searchParamBuilder.withBFloat16Vectors((List<ByteBuffer>) searchIteratorParam.getVectors());
+                break;
+            case SparseFloatVector:
+                searchParamBuilder.withSparseFloatVectors((List<SortedMap<Long, Float>>) searchIteratorParam.getVectors());
+                break;
+            default:
+                searchParamBuilder.withVectors(searchIteratorParam.getVectors());
+                break;
+        }
+    }
+
     private int extendBatchSize(int batchSize, boolean toExtendBatchSize, Map<String, Object> nextParams) {
         int extendRate = 1;
 

+ 138 - 39
src/main/java/io/milvus/param/dml/SearchIteratorParam.java

@@ -22,6 +22,7 @@ package io.milvus.param.dml;
 import com.google.common.collect.Lists;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.exception.ParamException;
+import io.milvus.grpc.PlaceholderType;
 import io.milvus.param.Constant;
 import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
@@ -32,6 +33,7 @@ import org.jetbrains.annotations.NotNull;
 
 import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.SortedMap;
 
 import static io.milvus.param.Constant.UNLIMITED;
 
@@ -58,6 +60,8 @@ public class SearchIteratorParam {
     private final Long gracefulTime;
     private final ConsistencyLevelEnum consistencyLevel;
     private final boolean ignoreGrowing;
+    private final String groupByFieldName;
+    private final PlaceholderType plType;
 
     private final long batchSize;
 
@@ -79,6 +83,9 @@ public class SearchIteratorParam {
         this.gracefulTime = builder.gracefulTime;
         this.consistencyLevel = builder.consistencyLevel;
         this.ignoreGrowing = builder.ignoreGrowing;
+        this.groupByFieldName = builder.groupByFieldName;
+        this.plType = builder.plType;
+        
         this.batchSize = builder.batchSize;
     }
 
@@ -107,6 +114,12 @@ public class SearchIteratorParam {
         private Long gracefulTime = 5000L;
         private ConsistencyLevelEnum consistencyLevel = null;
         private Boolean ignoreGrowing = Boolean.FALSE;
+        private String groupByFieldName;
+
+        // plType is used to distinct vector type
+        // for Float16Vector/BFloat16Vector and BinaryVector, user inputs ByteBuffer
+        // the sdk cannot distinct a ByteBuffer is a BinarVector or a Float16Vector
+        private PlaceholderType plType = PlaceholderType.None;
 
         private Long batchSize = 1000L;
 
@@ -241,18 +254,89 @@ public class SearchIteratorParam {
 
         /**
          * Sets the target vectors.
+         * Note: Deprecated in v2.4.0, for the reason that the sdk cannot know a ByteBuffer
+         *       is a BinarVector or Float16Vector/BFloat16Vector.
+         *       Replaced by withFloatVectors/withBinaryVectors/withFloat16Vectors/withBFloat16Vectors/withSparseFloatVectors.
+         *       It still works for FloatVector/BinarVector/SparseVector, don't use it for Float16Vector/BFloat16Vector.
          *
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float;
          *                if vector type is BinaryVector, vectors is List of ByteBuffer;
+         *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
          * @return <code>Builder</code>
          */
+        @Deprecated
         public Builder withVectors(@NonNull List<?> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
             return this;
         }
 
+        /**
+         * Sets the target vectors to search on FloatVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloatVectors(@NonNull List<List<Float>> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            this.plType = PlaceholderType.FloatVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BinaryVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBinaryVectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            this.plType = PlaceholderType.BinaryVector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on Float16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            this.plType = PlaceholderType.Float16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on BFloat16Vector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withBFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            this.plType = PlaceholderType.BFloat16Vector;
+            return this;
+        }
+
+        /**
+         * Sets the target vectors to search on SparseFloatVector field.
+         *
+         * @param vectors target vectors to search
+         * @return <code>Builder</code>
+         */
+        public Builder withSparseFloatVectors(@NonNull List<SortedMap<Long, Float>> vectors) {
+            this.vectors = vectors;
+            this.NQ = (long) vectors.size();
+            this.plType = PlaceholderType.SparseFloatVector;
+            return this;
+        }
+
         /**
          * Specifies the decimal place of the returned results.
          *
@@ -290,6 +374,17 @@ public class SearchIteratorParam {
             return this;
         }
 
+        /**
+         * Sets field name to do grouping.
+         *
+         * @param groupByFieldName field name to do grouping
+         * @return <code>Builder</code>
+         */
+        public Builder withGroupByFieldName(@NonNull String groupByFieldName) {
+            this.groupByFieldName = groupByFieldName;
+            return this;
+        }
+
         /**
          * Specify a value to control the number of entities returned per batch. Must be a positive value.
          * Default value is 1000, will return without batchSize.
@@ -323,51 +418,55 @@ public class SearchIteratorParam {
                 throw new ParamException("The guarantee timestamp must be greater than 0");
             }
 
-            if (vectors == null || vectors.isEmpty()) {
-                throw new ParamException("Target vectors can not be empty");
-            }
-
             if (metricType == MetricType.None) {
                 throw new ParamException("must specify metricType for search iterator");
             }
 
-            if (vectors.get(0) instanceof List) {
-                if (vectors.size() > 1) {
-                    throw new ParamException("Not support search iteration over multiple vectors at present");
-                }
-
-                // float vectors
-                List<?> first = (List<?>) vectors.get(0);
-                if (!(first.get(0) instanceof Float)) {
-                    throw new ParamException("Float vector field's value must be Lst<Float>");
-                }
-
-                int dim = first.size();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    List<?> temp = (List<?>) vectors.get(i);
-                    if (dim != temp.size()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
-                }
-            } else if (vectors.get(0) instanceof ByteBuffer) {
-                // binary vectors
-                if (vectors.size() > 1) {
-                    throw new ParamException("Not support search iteration over multiple vectors at present");
-                }
-
-                ByteBuffer first = (ByteBuffer) vectors.get(0);
-                int dim = first.position();
-                for (int i = 1; i < vectors.size(); ++i) {
-                    ByteBuffer temp = (ByteBuffer) vectors.get(i);
-                    if (dim != temp.position()) {
-                        throw new ParamException("Target vector dimension must be equal");
-                    }
-                }
-            } else {
-                throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
+            verifyVectors(vectors);
+            return new SearchIteratorParam(this);
+        }
+    }
+
+    public static void verifyVectors(List<?> vectors) {
+        if (vectors == null || vectors.isEmpty()) {
+            throw new ParamException("Target vectors can not be empty");
+        }
+
+        if (vectors.get(0) instanceof List) {
+            if (vectors.size() > 1) {
+                throw new ParamException("Not support search iteration over multiple vectors at present");
             }
 
-            return new SearchIteratorParam(this);
+            // float vectors
+            List<?> first = (List<?>) vectors.get(0);
+            if (!(first.get(0) instanceof Float)) {
+                throw new ParamException("Float vector field's value must be Lst<Float>");
+            }
+        } else if (vectors.get(0) instanceof ByteBuffer) {
+            // binary vectors
+            if (vectors.size() > 1) {
+                throw new ParamException("Not support search iteration over multiple vectors at present");
+            }
+        } else if (vectors.get(0) instanceof SortedMap) {
+            // SparseFloatVector
+            if (vectors.size() > 1) {
+                throw new ParamException("Not support search iteration over multiple vectors at present");
+            }
+
+            // TODO: here only check the first element, potential risk
+            SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);
+            if (!(map.firstKey() instanceof Long)) {
+                throw new ParamException("key type of SparseFloatVector must be Long");
+            }
+            if (!(map.get(map.firstKey()) instanceof Float)) {
+                throw new ParamException("Value type of SparseFloatVector must be Float");
+            }
+        } else {
+            String msg = "Search target vector type is illegal." +
+                    " Only allow List<Float> for FloatVector," +
+                    " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
+                    " List<SortedMap<Long, Float>> for SparseFloatVector.";
+            throw new ParamException(msg);
         }
     }