Browse Source

Add group by field for hybridsearch (#1159)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 8 months ago
parent
commit
a0f2f7c1b9

+ 8 - 0
src/main/java/io/milvus/param/ParamUtils.java

@@ -850,6 +850,14 @@ public class ParamUtils {
             builder.addRequests(searchRequest);
         }
 
+        if (!StringUtils.isEmpty(requestParam.getGroupByFieldName())) {
+            builder.addRankParams(
+                    KeyValuePair.newBuilder()
+                            .setKey(Constant.GROUP_BY_FIELD)
+                            .setValue(requestParam.getGroupByFieldName())
+                            .build());
+        }
+
         // set ranker
         BaseRanker ranker = requestParam.getRanker();
         Map<String, String> props = ranker.getProperties();

+ 15 - 4
src/main/java/io/milvus/param/dml/HybridSearchParam.java

@@ -22,8 +22,6 @@ package io.milvus.param.dml;
 import com.google.common.collect.Lists;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.exception.ParamException;
-import io.milvus.param.Constant;
-import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
 
 import io.milvus.param.dml.ranker.BaseRanker;
@@ -31,9 +29,7 @@ import lombok.Getter;
 import lombok.NonNull;
 import lombok.ToString;
 
-import java.nio.ByteBuffer;
 import java.util.List;
-import java.util.SortedMap;
 
 /**
  * Parameters for <code>search</code> interface.
@@ -51,6 +47,8 @@ public class HybridSearchParam {
     private final int roundDecimal;
     private final ConsistencyLevelEnum consistencyLevel;
 
+    private final String groupByFieldName;
+
     private HybridSearchParam(@NonNull Builder builder) {
         this.databaseName = builder.databaseName;
         this.collectionName = builder.collectionName;
@@ -61,6 +59,7 @@ public class HybridSearchParam {
         this.outFields = builder.outFields;
         this.roundDecimal = builder.roundDecimal;
         this.consistencyLevel = builder.consistencyLevel;
+        this.groupByFieldName = builder.groupByFieldName;
     }
 
     public static Builder newBuilder() {
@@ -80,6 +79,7 @@ public class HybridSearchParam {
         private final List<String> outFields = Lists.newArrayList();
         private Integer roundDecimal = -1;
         private ConsistencyLevelEnum consistencyLevel = null;
+        private String groupByFieldName = null;
 
         Builder() {
         }
@@ -209,6 +209,17 @@ public class HybridSearchParam {
             return this;
         }
 
+        /**
+         * Groups the results by a scalar field name.
+         *
+         * @param fieldName a scalar field name
+         * @return <code>Builder</code>
+         */
+        public Builder withGroupByFieldName(@NonNull String groupByFieldName) {
+            this.groupByFieldName = groupByFieldName;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link HybridSearchParam} instance.
          *

+ 4 - 0
src/main/java/io/milvus/v2/service/vector/request/AnnSearchReq.java

@@ -19,6 +19,7 @@
 
 package io.milvus.v2.service.vector.request;
 
+import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.service.vector.request.data.BaseVector;
 import lombok.Builder;
 import lombok.Data;
@@ -35,4 +36,7 @@ public class AnnSearchReq {
     private String expr = "";
     private List<BaseVector> vectors;
     private String params;
+
+    @Builder.Default
+    private IndexParam.MetricType metricType = null;
 }

+ 2 - 0
src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java

@@ -42,4 +42,6 @@ public class HybridSearchReq
     private int roundDecimal = -1;
     @Builder.Default
     private ConsistencyLevel consistencyLevel = null;
+
+    private String groupByFieldName;
 }

+ 2 - 0
src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

@@ -20,6 +20,7 @@
 package io.milvus.v2.service.vector.request;
 
 import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.service.vector.request.data.BaseVector;
 import lombok.Builder;
 import lombok.Data;
@@ -38,6 +39,7 @@ public class SearchReq {
     private List<String> partitionNames = new ArrayList<>();
     @Builder.Default
     private String annsField = "";
+    private IndexParam.MetricType metricType;
     private int topK;
     private String filter;
     @Builder.Default

+ 22 - 0
src/main/java/io/milvus/v2/utils/VectorUtils.java

@@ -156,6 +156,13 @@ public class VectorUtils {
                         .setValue(String.valueOf(request.getOffset()))
                         .build());
 
+        if (null != request.getMetricType()) {
+            builder.addSearchParams(
+                    KeyValuePair.newBuilder()
+                            .setKey(Constant.METRIC_TYPE)
+                            .setValue(request.getMetricType().name())
+                            .build());
+        }
 
         if (null != request.getSearchParams()) {
             try {
@@ -287,6 +294,13 @@ public class VectorUtils {
                                 .setKey(Constant.TOP_K)
                                 .setValue(String.valueOf(annSearchReq.getTopK()))
                                 .build());
+        if (annSearchReq.getMetricType() != null) {
+            builder.addSearchParams(
+                    KeyValuePair.newBuilder()
+                            .setKey(Constant.METRIC_TYPE)
+                            .setValue(annSearchReq.getMetricType().name())
+                            .build());
+        }
 
         // params
         String params = "{}";
@@ -347,6 +361,14 @@ public class VectorUtils {
             propertiesList.forEach(builder::addRankParams);
         }
 
+        if (request.getGroupByFieldName() != null && !request.getGroupByFieldName().isEmpty()) {
+            builder.addRankParams(
+                    KeyValuePair.newBuilder()
+                            .setKey(Constant.GROUP_BY_FIELD)
+                            .setValue(request.getGroupByFieldName())
+                            .build());
+        }
+
         // output fields
         if (request.getOutFields() != null && !request.getOutFields().isEmpty()) {
             request.getOutFields().forEach(builder::addOutputFields);