Browse Source

Unify Function and Rerank (#1515)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 6 days ago
parent
commit
bec779efbf

+ 2 - 2
docker-compose.yml

@@ -32,7 +32,7 @@ services:
 
   standalone:
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:v2.6.0-rc1
+    image: milvusdb/milvus:v2.6.0
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-test-slave-standalone
-    image: milvusdb/milvus:v2.6.0-rc1
+    image: milvusdb/milvus:v2.6.0
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcdslave:2379

+ 1 - 1
examples/src/main/java/io/milvus/v2/HybridSearchExample.java

@@ -213,7 +213,7 @@ public class HybridSearchExample {
         HybridSearchReq hybridSearchReq = HybridSearchReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .searchRequests(searchRequests)
-                .ranker(new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f)))
+                .ranker(WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build())
                 .limit(5)
                 .consistencyLevel(ConsistencyLevel.BOUNDED)
                 .build();

+ 2 - 3
sdk-core/src/main/java/io/milvus/v2/service/vector/request/HybridSearchReq.java

@@ -20,8 +20,7 @@
 package io.milvus.v2.service.vector.request;
 
 import io.milvus.v2.common.ConsistencyLevel;
-import io.milvus.v2.service.collection.request.LoadCollectionReq;
-import io.milvus.v2.service.vector.request.ranker.BaseRanker;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import lombok.Builder;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
@@ -36,7 +35,7 @@ public class HybridSearchReq
     private String collectionName;
     private List<String> partitionNames;
     private List<AnnSearchReq> searchRequests;
-    private BaseRanker ranker;
+    private CreateCollectionReq.Function ranker;
     @Builder.Default
     @Deprecated
     private int topK = 0; // deprecated, replaced by "limit"

+ 2 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

@@ -21,6 +21,7 @@ package io.milvus.v2.service.vector.request;
 
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.vector.request.data.BaseVector;
 
 import lombok.Builder;
@@ -66,6 +67,7 @@ public class SearchReq {
     private String groupByFieldName;
     private Integer groupSize;
     private Boolean strictGroupSize;
+    private CreateCollectionReq.Function ranker;
 
     // Expression template, to improve expression parsing performance in complicated list
     // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]

+ 0 - 26
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BaseRanker.java

@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package io.milvus.v2.service.vector.request.ranker;
-
-import java.util.Map;
-
-public abstract class BaseRanker {
-    public abstract Map<String, String> getProperties();
-}

+ 52 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/DecayRanker.java

@@ -0,0 +1,52 @@
+package io.milvus.v2.service.vector.request.ranker;
+
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.Builder;
+import lombok.experimental.SuperBuilder;
+
+import java.util.Map;
+
+/**
+ * The Decay reranking strategy, which by adjusting search rankings based on numeric field values.
+ * Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
+ *
+ * You also can declare a decay ranker by Function
+ * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
+ *                 .functionType(FunctionType.RERANK)
+ *                 .name("time_decay")
+ *                 .description("time decay")
+ *                 .inputFieldNames(Collections.singletonList("timestamp"))
+ *                 .param("reranker", "decay")
+ *                 .param("function", "gauss")
+ *                 .param("origin", "1000")
+ *                 .param("scale", "10000")
+ *                 .param("offset", "24")
+ *                 .param("decay", "0.5")
+ *                 .build();
+ */
+@SuperBuilder
+public class DecayRanker extends CreateCollectionReq.Function {
+    @Builder.Default
+    private String function = "gauss";
+    private Number origin;
+    private Number scale;
+
+    public FunctionType getFunctionType() {
+        return FunctionType.RERANK;
+    }
+
+    public Map<String, String> getParams() {
+        // the parent params might contain "offset" and "decay"
+        Map<String, String> props = super.getParams();
+        props.put("reranker", "decay");
+        props.put("function", function); // "gauss", "exp", or "linear"
+        if (origin != null) {
+            props.put("origin", origin.toString());
+        }
+        if (scale != null) {
+            props.put("scale", scale.toString());
+        }
+        return props;
+    }
+}

+ 55 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/ModelRanker.java

@@ -0,0 +1,55 @@
+package io.milvus.v2.service.vector.request.ranker;
+
+import com.google.gson.JsonArray;
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.Builder;
+import lombok.experimental.SuperBuilder;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * The Model reranking strategy, which transforms Milvus search by integrating advanced language models
+ * that understand semantic relationships between queries and documents.
+ * Read the doc for more info: https://milvus.io/docs/model-ranker-overview.md
+ *
+ * You also can declare a model ranker by Function
+ * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
+ *                 .functionType(FunctionType.RERANK)
+ *                 .name("semantic_ranker")
+ *                 .description("semantic ranker")
+ *                 .inputFieldNames(Collections.singletonList("document"))
+ *                 .param("reranker", "model")
+ *                 .param("provider", "tei")
+ *                 .param("queries", "[\"machine learning for time series\"]")
+ *                 .param("endpoint", "http://model-service:8080")
+ *                 .build();
+ */
+@SuperBuilder
+public class ModelRanker extends CreateCollectionReq.Function {
+    @Builder.Default
+    private String provider = "tei";
+    @Builder.Default
+    private List<String> queries = new ArrayList<>();
+    private String endpoint;
+
+    public FunctionType getFunctionType() {
+        return FunctionType.RERANK;
+    }
+
+    public Map<String, String> getParams() {
+        // the parent params might contain "offset" and "decay"
+        Map<String, String> props = super.getParams();
+        props.put("reranker", "model");
+        props.put("provider", provider); // "tei" or "vllm"
+        JsonArray json = new JsonArray();
+        queries.forEach(json::add);
+        props.put("queries", json.toString());
+        if (endpoint != null) {
+            props.put("endpoint", endpoint);
+        }
+        return props;
+    }
+}

+ 29 - 6
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/RRFRanker.java

@@ -20,26 +20,49 @@
 package io.milvus.v2.service.vector.request.ranker;
 
 import com.google.gson.JsonObject;
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.Builder;
+import lombok.experimental.SuperBuilder;
 
 import java.util.HashMap;
 import java.util.Map;
 
 /**
  * The RRF reranking strategy, which merges results from multiple searches, favoring items that consistently appear.
+ * Read the doc for more info: https://milvus.io/docs/rrf-ranker.md
+ *
+ * Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
+ * https://milvus.io/docs/decay-ranker-overview.md
+ * https://milvus.io/docs/model-ranker-overview.md
+ * So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with
+ * RRFRanker/WeightedRanker in some users client code. We will mention it in release note.
+ *  * In old client code, to declare a WeightedRanker:
+ *  *   RRFRanker ranker = new RRFRanker(20)
+ *  * After this change, the client code should be changed accordingly:
+ *  *   RRFRanker ranker = RRFRanker.builder().k(20).build()
+ *
+ * You also can declare a rrf ranker by Function
+ * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
+ *                 .functionType(FunctionType.RERANK)
+ *                 .param("strategy", "rrf")
+ *                 .param("params", "{\"k\": 60}")
+ *                 .build();
  */
-public class RRFRanker extends BaseRanker {
+@SuperBuilder
+public class RRFRanker extends CreateCollectionReq.Function {
+    @Builder.Default
     private int k = 60;
 
-    public RRFRanker(int k) {
-        this.k = k;
+    public FunctionType getFunctionType() {
+        return FunctionType.RERANK;
     }
 
-    @Override
-    public Map<String, String> getProperties() {
+    public Map<String, String> getParams() {
         JsonObject params = new JsonObject();
         params.addProperty("k", this.k);
 
-        Map<String, String> props = new HashMap<>();
+        Map<String, String> props = super.getParams();
         props.put("strategy", "rrf");
         props.put("params", params.toString());
         return props;

+ 31 - 7
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/WeightedRanker.java

@@ -20,8 +20,13 @@
 package io.milvus.v2.service.vector.request.ranker;
 
 import com.google.gson.JsonObject;
+import io.milvus.common.clientenum.FunctionType;
 import io.milvus.common.utils.JsonUtils;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.Builder;
+import lombok.experimental.SuperBuilder;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -29,20 +34,39 @@ import java.util.Map;
 /**
  * The Average Weighted Scoring reranking strategy, which prioritizes vectors based on relevance,
  * averaging their significance.
+ * Read the doc for more info: https://milvus.io/docs/weighted-ranker.md
+ *
+ * Note: In v2.6, the Function and Rerank have been unified to support more rerank types: decay and model ranker
+ * https://milvus.io/docs/decay-ranker-overview.md
+ * https://milvus.io/docs/model-ranker-overview.md
+ * So we have to inherit the BaseRanker from Function, this change will lead to uncomfortable issues with
+ * RRFRanker/WeightedRanker in some users client code. We will mention it in release note.
+ * In old client code, to declare a WeightedRanker:
+ *   WeightedRanker ranker = new WeightedRanker(Arrays.asList(0.2f, 0.5f, 0.6f))
+ * After this change, the client code should be changed accordingly:
+ *   WeightedRanker ranker = WeightedRanker.builder().weights(Arrays.asList(0.2f, 0.5f, 0.6f)).build()
+ *
+ * You also can declare a weighter ranker by Function
+ * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
+ *                 .functionType(FunctionType.RERANK)
+ *                 .param("strategy", "weighted")
+ *                 .param("params", "{\"weights\": [0.4, 0.6]}")
+ *                 .build();
  */
-public class WeightedRanker extends BaseRanker {
-    private List<Float> weights;
+@SuperBuilder
+public class WeightedRanker extends CreateCollectionReq.Function {
+    @Builder.Default
+    private List<Float> weights = new ArrayList<>();
 
-    public WeightedRanker(List<Float> weights) {
-        this.weights = weights;
+    public FunctionType getFunctionType() {
+        return FunctionType.RERANK;
     }
 
-    @Override
-    public Map<String, String> getProperties() {
+    public Map<String, String> getParams() {
         JsonObject params = new JsonObject();
         params.add("weights", JsonUtils.toJsonTree(this.weights).getAsJsonArray());
 
-        Map<String, String> props = new HashMap<>();
+        Map<String, String> props = super.getParams();
         props.put("strategy", "weighted");
         props.put("params", params.toString());
         return props;

+ 33 - 5
sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java

@@ -31,9 +31,11 @@ import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
 import io.milvus.v2.exception.ErrorCode;
 import io.milvus.v2.exception.MilvusClientException;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.vector.request.*;
-import io.milvus.v2.service.vector.request.ranker.BaseRanker;
 import io.milvus.v2.service.vector.request.data.*;
+import io.milvus.v2.service.vector.request.ranker.RRFRanker;
+import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
 import lombok.NonNull;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.StringUtils;
@@ -279,6 +281,12 @@ public class VectorUtils {
             builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode());
         }
 
+        // set ranker, support reranking search result from v2.6.1
+        CreateCollectionReq.Function ranker = request.getRanker();
+        if (ranker != null) {
+            builder.setFunctionScore(convertFunctionScore(ranker));
+        }
+
         return builder.build();
     }
 
@@ -473,16 +481,25 @@ public class VectorUtils {
         }
 
         // set ranker
-        BaseRanker ranker = request.getRanker();
-        if (request.getRanker() == null) {
+        CreateCollectionReq.Function ranker = request.getRanker();
+        if (ranker == null) {
             throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
         }
 
-        // topK value is deprecated, always use "limit" to set the topK
-        Map<String, String> props = ranker.getProperties();
+        Map<String, String> props = new HashMap<>();
         props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
         props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
         props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
+
+        if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
+            // old logic for RRF/Weighted ranker
+            Map<String, String> params = ranker.getParams();
+            props.putAll(params);
+        } else {
+            // new logic for Decay/Model ranker
+            builder.setFunctionScore(convertFunctionScore(ranker));
+        }
+
         List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
         if (CollectionUtils.isNotEmpty(propertiesList)) {
             propertiesList.forEach(builder::addRankParams);
@@ -528,6 +545,17 @@ public class VectorUtils {
         return builder.build();
     }
 
+    private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
+        FunctionSchema schema = FunctionSchema.newBuilder()
+                .setName(function.getName())
+                .setDescription(function.getDescription())
+                .setType(FunctionType.forNumber(function.getFunctionType().getCode()))
+                .addAllInputFieldNames(function.getInputFieldNames())
+                .addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
+                .build();
+        return FunctionScore.newBuilder().addFunctions(schema).build();
+    }
+
     public String getExprById(String primaryFieldName, List<?> ids) {
         StringBuilder sb = new StringBuilder();
         sb.append(primaryFieldName).append(" in [");

+ 1 - 1
sdk-core/src/test/java/io/milvus/TestUtils.java

@@ -11,7 +11,7 @@ public class TestUtils {
     private int dimension = 256;
     private static final Random RANDOM = new Random();
 
-    public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0-rc1";
+    public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.0";
 
     public TestUtils(int dimension) {
         this.dimension = dimension;

+ 2 - 2
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1047,7 +1047,7 @@ class MilvusClientV2DockerTest {
                     return HybridSearchReq.builder()
                             .collectionName(randomCollectionName)
                             .searchRequests(searchRequests)
-                            .ranker(new RRFRanker(20))
+                            .ranker(RRFRanker.builder().k(20).build())
                             .limit(topk)
                             .consistencyLevel(ConsistencyLevel.BOUNDED)
                             .build();
@@ -2853,7 +2853,7 @@ class MilvusClientV2DockerTest {
                                     .databaseName(dbName)
                                     .collectionName(randomCollectionName)
                                     .searchRequests(Collections.singletonList(subReq))
-                                    .ranker(new RRFRanker(20))
+                                    .ranker(RRFRanker.builder().k(20).build())
                                     .limit(5)
                                     .build());
                             List<List<SearchResp.SearchResult>> oneResult = searchResp.getSearchResults();