|
@@ -20,6 +20,8 @@
|
|
package io.milvus.v2.utils;
|
|
package io.milvus.v2.utils;
|
|
|
|
|
|
import com.google.gson.JsonElement;
|
|
import com.google.gson.JsonElement;
|
|
|
|
+import com.google.gson.JsonObject;
|
|
|
|
+import com.google.gson.JsonParser;
|
|
import com.google.gson.reflect.TypeToken;
|
|
import com.google.gson.reflect.TypeToken;
|
|
import com.google.protobuf.ByteString;
|
|
import com.google.protobuf.ByteString;
|
|
import io.milvus.common.utils.GTsDict;
|
|
import io.milvus.common.utils.GTsDict;
|
|
@@ -33,6 +35,7 @@ import io.milvus.v2.exception.ErrorCode;
|
|
import io.milvus.v2.exception.MilvusClientException;
|
|
import io.milvus.v2.exception.MilvusClientException;
|
|
import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
|
import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
|
import io.milvus.v2.service.vector.request.*;
|
|
import io.milvus.v2.service.vector.request.*;
|
|
|
|
+import io.milvus.v2.service.vector.request.FunctionScore;
|
|
import io.milvus.v2.service.vector.request.data.*;
|
|
import io.milvus.v2.service.vector.request.data.*;
|
|
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
|
|
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
|
|
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
|
|
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
|
|
@@ -283,8 +286,14 @@ public class VectorUtils {
|
|
|
|
|
|
// set ranker, support reranking search result from v2.6.1
|
|
// set ranker, support reranking search result from v2.6.1
|
|
CreateCollectionReq.Function ranker = request.getRanker();
|
|
CreateCollectionReq.Function ranker = request.getRanker();
|
|
- if (ranker != null) {
|
|
|
|
- builder.setFunctionScore(convertFunctionScore(ranker));
|
|
|
|
|
|
+ io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
|
|
|
|
+ if (ranker != null && functionScore != null) {
|
|
|
|
+ throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
|
|
|
|
+ }
|
|
|
|
+ if (functionScore != null) {
|
|
|
|
+ builder.setFunctionScore(convertFunctionScore(functionScore));
|
|
|
|
+ } else if (ranker != null) {
|
|
|
|
+ builder.setFunctionScore(convertOneFunction(ranker));
|
|
}
|
|
}
|
|
|
|
|
|
return builder.build();
|
|
return builder.build();
|
|
@@ -480,24 +489,28 @@ public class VectorUtils {
|
|
builder.addRequests(searchRequest);
|
|
builder.addRequests(searchRequest);
|
|
}
|
|
}
|
|
|
|
|
|
- // set ranker
|
|
|
|
- CreateCollectionReq.Function ranker = request.getRanker();
|
|
|
|
- if (ranker == null) {
|
|
|
|
- throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
Map<String, String> props = new HashMap<>();
|
|
Map<String, String> props = new HashMap<>();
|
|
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
|
|
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
|
|
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
|
|
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
|
|
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
|
|
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
|
|
|
|
|
|
- if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
|
|
|
|
- // old logic for RRF/Weighted ranker
|
|
|
|
- Map<String, String> params = ranker.getParams();
|
|
|
|
- props.putAll(params);
|
|
|
|
- } else {
|
|
|
|
- // new logic for Decay/Model ranker
|
|
|
|
- builder.setFunctionScore(convertFunctionScore(ranker));
|
|
|
|
|
|
+ // set ranker
|
|
|
|
+ CreateCollectionReq.Function ranker = request.getRanker();
|
|
|
|
+ io.milvus.v2.service.vector.request.FunctionScore functionScore = request.getFunctionScore();
|
|
|
|
+ if (ranker != null && functionScore != null) {
|
|
|
|
+ throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Not allow to set both ranker and functionScore.");
|
|
|
|
+ }
|
|
|
|
+ if (functionScore != null) {
|
|
|
|
+ builder.setFunctionScore(convertFunctionScore(functionScore));
|
|
|
|
+ } else if (ranker != null) {
|
|
|
|
+ if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
|
|
|
|
+ // old logic for RRF/Weighted ranker
|
|
|
|
+ Map<String, String> params = ranker.getParams();
|
|
|
|
+ props.putAll(params);
|
|
|
|
+ } else {
|
|
|
|
+ // new logic for Decay/Model ranker
|
|
|
|
+ builder.setFunctionScore(convertOneFunction(ranker));
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
|
|
List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
|
|
@@ -545,15 +558,42 @@ public class VectorUtils {
|
|
return builder.build();
|
|
return builder.build();
|
|
}
|
|
}
|
|
|
|
|
|
- private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
|
|
|
|
- FunctionSchema schema = FunctionSchema.newBuilder()
|
|
|
|
|
|
+ private FunctionSchema convertFunctionSchema(CreateCollectionReq.Function function) {
|
|
|
|
+ Map<String, String> params = function.getParams();
|
|
|
|
+ // FunctionSchema type keyword is "reranker", old RRF/Weighted ranker type keyword is "strategy"
|
|
|
|
+ // FunctionSchema parameters are flat, old RRF/Weighted parameters are wrapped by "params"
|
|
|
|
+ if (function instanceof RRFRanker || function instanceof WeightedRanker) {
|
|
|
|
+ String name = (function instanceof RRFRanker) ? "rrf" : "weighted";
|
|
|
|
+ params.put("reranker", name);
|
|
|
|
+ JsonObject inner = JsonParser.parseString(params.get("params")).getAsJsonObject();
|
|
|
|
+ for (String key : inner.keySet()) {
|
|
|
|
+ params.put(key, inner.get(key).toString());
|
|
|
|
+ }
|
|
|
|
+ params.remove("strategy");
|
|
|
|
+ params.remove("params");
|
|
|
|
+ }
|
|
|
|
+ return FunctionSchema.newBuilder()
|
|
.setName(function.getName())
|
|
.setName(function.getName())
|
|
.setDescription(function.getDescription())
|
|
.setDescription(function.getDescription())
|
|
.setType(FunctionType.forNumber(function.getFunctionType().getCode()))
|
|
.setType(FunctionType.forNumber(function.getFunctionType().getCode()))
|
|
.addAllInputFieldNames(function.getInputFieldNames())
|
|
.addAllInputFieldNames(function.getInputFieldNames())
|
|
- .addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
|
|
|
|
|
|
+ .addAllParams(ParamUtils.AssembleKvPair(params))
|
|
.build();
|
|
.build();
|
|
- return FunctionScore.newBuilder().addFunctions(schema).build();
|
|
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private io.milvus.grpc.FunctionScore convertOneFunction(CreateCollectionReq.Function function) {
|
|
|
|
+ FunctionSchema schema = convertFunctionSchema(function);
|
|
|
|
+ return io.milvus.grpc.FunctionScore.newBuilder().addFunctions(schema).build();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private io.milvus.grpc.FunctionScore convertFunctionScore(FunctionScore functionScore) {
|
|
|
|
+ io.milvus.grpc.FunctionScore.Builder builder = io.milvus.grpc.FunctionScore.newBuilder();
|
|
|
|
+ for (CreateCollectionReq.Function function : functionScore.getFunctions()) {
|
|
|
|
+ FunctionSchema schema = convertFunctionSchema(function);
|
|
|
|
+ builder.addFunctions(schema);
|
|
|
|
+ }
|
|
|
|
+ builder.addAllParams(ParamUtils.AssembleKvPair(functionScore.getParams()));
|
|
|
|
+ return builder.build();
|
|
}
|
|
}
|
|
|
|
|
|
public String getExprById(String primaryFieldName, List<?> ids) {
|
|
public String getExprById(String primaryFieldName, List<?> ids) {
|