Pārlūkot izejas kodu

Support FunctionScore, multi-reranker for search/hybridSearch (#1587)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 mēnesi atpakaļ
vecāks
revīzija
36947497cd

+ 4 - 2
examples/src/main/java/io/milvus/v2/HybridSearchExample.java

@@ -31,6 +31,7 @@ import io.milvus.v2.service.collection.request.AddFieldReq;
 import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.collection.request.DropCollectionReq;
 import io.milvus.v2.service.vector.request.AnnSearchReq;
+import io.milvus.v2.service.vector.request.FunctionScore;
 import io.milvus.v2.service.vector.request.HybridSearchReq;
 import io.milvus.v2.service.vector.request.InsertReq;
 import io.milvus.v2.service.vector.request.QueryReq;
@@ -122,7 +123,6 @@ public class HybridSearchExample {
                 .metricType(BINARY_VECTOR_METRIC)
                 .build());
         Map<String,Object> fv16Params = new HashMap<>();
-        fv16Params.clear();
         fv16Params.put("M",16);
         fv16Params.put("efConstruction",64);
         indexes.add(IndexParam.builder()
@@ -212,7 +212,9 @@ public class HybridSearchExample {
         HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .searchRequests(searchRequests)
-                .ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
+                .functionScore(FunctionScore.builder()
+                        .addFunction(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
+                        .build())
                 .limit(5)
                 .consistencyLevel(ConsistencyLevel.BOUNDED)
                 .build();

+ 3 - 1
sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java

@@ -210,8 +210,10 @@ public class QueryIterator {
     private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) {
         // for seeking offset, no need to return output fields
         List<String> outputFields = new ArrayList<>();
+        boolean reduceStopForBest = queryIteratorParam.isReduceStopForBest();
         if (!isSeek) {
             outputFields = queryIteratorParam.getOutFields();
+            reduceStopForBest = false;
         }
         QueryParam queryParam = QueryParam.newBuilder()
                 .withDatabaseName(queryIteratorParam.getDatabaseName())
@@ -230,7 +232,7 @@ public class QueryIterator {
         // reduce stop for best
         builder.addQueryParams(KeyValuePair.newBuilder()
                 .setKey(Constant.REDUCE_STOP_FOR_BEST)
-                .setValue(String.valueOf(queryIteratorParam.isReduceStopForBest()))
+                .setValue(String.valueOf(reduceStopForBest))
                 .build());
 
         // iterator

+ 2 - 1
sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

@@ -194,7 +194,8 @@ public class CreateCollectionReq {
     @Data
     @SuperBuilder
     public static class Function {
-        private String name;
+        @Builder.Default
+        private String name = "";
         @Builder.Default
         private String description = "";
         @Builder.Default

+ 50 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/FunctionScore.java

@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2.service.vector.request;
+
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.Builder;
+import lombok.Data;
+import lombok.experimental.SuperBuilder;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+@Data
+@SuperBuilder
+public class FunctionScore {
+    @Builder.Default
+    private List<CreateCollectionReq.Function> functions = new ArrayList<>();
+    @Builder.Default
+    private Map<String, String> params = new HashMap<>();
+
+    public static abstract class FunctionScoreBuilder<C extends FunctionScore, B extends FunctionScore.FunctionScoreBuilder<C, B>> {
+        public B addFunction(CreateCollectionReq.Function func) {
+            if(null == this.functions$value ){
+                this.functions$value = new ArrayList<>();
+            }
+            this.functions$value.add(func);
+            this.functions$set = true;
+            return self();
+        }
+    }
+}

+ 5 - 1
sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java

@@ -35,7 +35,6 @@ public class HybridSearchReq
     private String collectionName;
     private List<String> partitionNames;
     private List<AnnSearchReq> searchRequests;
-    private CreateCollectionReq.Function ranker;
     @Builder.Default
     @Deprecated
     private int topK = 0; // deprecated, replaced by "limit"
@@ -51,6 +50,11 @@ public class HybridSearchReq
     private String groupByFieldName;
     private Integer groupSize;
     private Boolean strictGroupSize;
+    @Deprecated
+    private CreateCollectionReq.Function ranker;
+    // milvus v2.6.1 supports multi-rankers. The "ranker" still works. It is recommended
+    // to use functionScore even you have only one ranker. Not allow to set both.
+    private FunctionScore functionScore;
 
     public static abstract class HybridSearchReqBuilder<C extends HybridSearchReq, B extends HybridSearchReq.HybridSearchReqBuilder<C, B>> {
         // topK is deprecated, topK and limit must be the same value

+ 4 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

@@ -67,7 +67,11 @@ public class SearchReq {
     private String groupByFieldName;
     private Integer groupSize;
     private Boolean strictGroupSize;
+    @Deprecated
     private CreateCollectionReq.Function ranker;
+    // milvus v2.6.1 supports multi-rankers. The "ranker" still works. It is recommended
+    // to use functionScore even you have only one ranker. Not allow to set both.
+    private FunctionScore functionScore;
 
     // Expression template, to improve expression parsing performance in complicated list
     // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]

+ 59 - 19
sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java

@@ -20,6 +20,8 @@
 package io.milvus.v2.utils;
 
 import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
 import com.google.gson.reflect.TypeToken;
 import com.google.protobuf.ByteString;
 import io.milvus.common.utils.GTsDict;
@@ -33,6 +35,7 @@ import io.milvus.v2.exception.ErrorCode;
 import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.vector.request.*;
+import io.milvus.v2.service.vector.request.FunctionScore;
 import io.milvus.v2.service.vector.request.data.*;
 import io.milvus.v2.service.vector.request.ranker.RRFRanker;
 import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
@@ -283,8 +286,14 @@ public class VectorUtils {
 
         // set ranker, support reranking search result from v2.6.1
         CreateCollectionReq.Function ranker = request.getRanker();
-        if (ranker != null) {
-            builder.setFunctionScore(convertFunctionScore(ranker));
+        io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
+        if (ranker != null && functionScore != null) {
+            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
+        }
+        if (functionScore != null) {
+            builder.setFunctionScore(convertFunctionScore(functionScore));
+        } else if (ranker != null) {
+            builder.setFunctionScore(convertOneFunction(ranker));
         }
 
         return builder.build();
@@ -480,24 +489,28 @@ public class VectorUtils {
             builder.addRequests(searchRequest);
         }
 
-        // set ranker
-        CreateCollectionReq.Function ranker = request.getRanker();
-        if (ranker == null) {
-            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
-        }
-
         Map<String, String> props = new HashMap<>();
         props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
         props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
         props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
 
-        if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
-            // old logic for RRF/Weighted ranker
-            Map<String, String> params = ranker.getParams();
-            props.putAll(params);
-        } else {
-            // new logic for Decay/Model ranker
-            builder.setFunctionScore(convertFunctionScore(ranker));
+        // set ranker
+        CreateCollectionReq.Function ranker = request.getRanker();
+        io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
+        if (ranker != null && functionScore != null) {
+            throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
+        }
+        if (functionScore != null) {
+            builder.setFunctionScore(convertFunctionScore(functionScore));
+        } else if (ranker != null) {
+            if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
+                // old logic for RRF/Weighted ranker
+                Map<String, String> params = ranker.getParams();
+                props.putAll(params);
+            } else {
+                // new logic for Decay/Model ranker
+                builder.setFunctionScore(convertOneFunction(ranker));
+            }
         }
 
         List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
@@ -545,15 +558,42 @@ public class VectorUtils {
         return builder.build();
     }
 
-    private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
-        FunctionSchema schema = FunctionSchema.newBuilder()
+    private FunctionSchema convertFunctionSchema(CreateCollectionReq.Function function) {
+        Map<String, String> params = function.getParams();
+        // FunctionSchema type keyword is "reranker", old RRF/Weighted ranker type keyword is "strategy"
+        // FunctionSchema parameters are flat, old RRF/Weighted parameters are wrapped by "params"
+        if (function instanceof RRFRanker || function instanceof WeightedRanker) {
+            String name = (function instanceof RRFRanker) ? "rrf" : "weighted";
+            params.put("reranker", name);
+            JsonObject inner = JsonParser.parseString(params.get("params")).getAsJsonObject();
+            for (String key : inner.keySet()) {
+                params.put(key, inner.get(key).toString());
+            }
+            params.remove("strategy");
+            params.remove("params");
+        }
+        return FunctionSchema.newBuilder()
                 .setName(function.getName())
                 .setDescription(function.getDescription())
                 .setType(FunctionType.forNumber(function.getFunctionType().getCode()))
                 .addAllInputFieldNames(function.getInputFieldNames())
-                .addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
+                .addAllParams(ParamUtils.AssembleKvPair(params))
                 .build();
-        return FunctionScore.newBuilder().addFunctions(schema).build();
+    }
+
+    private io.milvus.grpc.FunctionScore convertOneFunction(CreateCollectionReq.Function function) {
+        FunctionSchema schema = convertFunctionSchema(function);
+        return io.milvus.grpc.FunctionScore.newBuilder().addFunctions(schema).build();
+    }
+
+    private io.milvus.grpc.FunctionScore convertFunctionScore(FunctionScore functionScore) {
+        io.milvus.grpc.FunctionScore.Builder builder = io.milvus.grpc.FunctionScore.newBuilder();
+        for (CreateCollectionReq.Function function : functionScore.getFunctions()) {
+            FunctionSchema schema = convertFunctionSchema(function);
+            builder.addFunctions(schema);
+        }
+        builder.addAllParams(ParamUtils.AssembleKvPair(functionScore.getParams()));
+        return builder.build();
     }
 
     public String getExprById(String primaryFieldName, List<?> ids) {

+ 32 - 13
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1074,8 +1074,8 @@ class MilvusClientV2DockerTest {
         // prepare sub requests
         int nq = 5;
         int topk = 10;
-        Function<Integer, HybridSearchReq> genRequestFunc =
-                sparseCount -> {
+        Function<Map<String, Object>, HybridSearchReq> genRequestFunc =
+                config -> {
                     List<BaseVector> floatVectors = new ArrayList<>();
                     List<BaseVector> binaryVectors = new ArrayList<>();
                     List<BaseVector> sparseVectors = new ArrayList<>();
@@ -1083,6 +1083,7 @@ class MilvusClientV2DockerTest {
                         floatVectors.add(new FloatVec(utils.generateFloatVector()));
                         binaryVectors.add(new BinaryVec(utils.generateBinaryVector()));
                     }
+                    int sparseCount = (Integer)config.get("sparseCount");
                     for (int i = 0; i < sparseCount; i++) {
                         sparseVectors.add(new SparseFloatVec(utils.generateSparseVector()));
                     }
@@ -1105,23 +1106,40 @@ class MilvusClientV2DockerTest {
                             .limit(7)
                             .build());
 
-                    return HybridSearchReq.builder()
-                            .collectionName(randomCollectionName)
-                            .searchRequests(searchRequests)
-                            .ranker(RRFRanker.builder().k(20).build())
-                            .limit(topk)
-                            .consistencyLevel(ConsistencyLevel.BOUNDED)
-                            .build();
+                    CreateCollectionReq.Function ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build();
+                    boolean useFunctionScore = (Boolean)config.get("useFunctionScore");
+                    if (useFunctionScore) {
+                        return HybridSearchReq.builder()
+                                .collectionName(randomCollectionName)
+                                .searchRequests(searchRequests)
+                                .functionScore(FunctionScore.builder().addFunction(ranker).build())
+                                .limit(topk)
+                                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                                .build();
+                    } else {
+                        return HybridSearchReq.builder()
+                                .collectionName(randomCollectionName)
+                                .searchRequests(searchRequests)
+                                .ranker(RRFRanker.builder().k(20).build())
+                                .limit(topk)
+                                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                                .build();
+                    }
         };
 
+        Map<String, Object> config = new HashMap<>();
+        config.put("sparseCount", 0);
+        config.put("useFunctionScore", false);
         // search with an empty nq, return error
-        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(0)));
+        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(config)));
 
         // unequal nq, return error
-        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(1)));
+        config.put("sparseCount", 1);
+        Assertions.assertThrows(MilvusClientException.class, ()->client.hybridSearch(genRequestFunc.apply(config)));
 
         // search on empty collection, no result returned
-        SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(nq));
+        config.put("sparseCount", nq);
+        SearchResp searchResp = client.hybridSearch(genRequestFunc.apply(config));
         List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
         Assertions.assertEquals(nq, searchResults.size());
         for (List<SearchResp.SearchResult> result : searchResults) {
@@ -1142,7 +1160,8 @@ class MilvusClientV2DockerTest {
         Assertions.assertEquals(count, rowCount);
 
         // search again, there are results
-        searchResp = client.hybridSearch(genRequestFunc.apply(nq));
+        config.put("useFunctionScore", true);
+        searchResp = client.hybridSearch(genRequestFunc.apply(config));
         searchResults = searchResp.getSearchResults();
         Assertions.assertEquals(nq, searchResults.size());
         for (int i = 0; i < nq; i++) {