|
@@ -31,9 +31,11 @@ import io.milvus.param.Constant;
|
|
|
import io.milvus.param.ParamUtils;
|
|
|
import io.milvus.v2.exception.ErrorCode;
|
|
|
import io.milvus.v2.exception.MilvusClientException;
|
|
|
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
|
|
import io.milvus.v2.service.vector.request.*;
|
|
|
-import io.milvus.v2.service.vector.request.ranker.BaseRanker;
|
|
|
import io.milvus.v2.service.vector.request.data.*;
|
|
|
+import io.milvus.v2.service.vector.request.ranker.RRFRanker;
|
|
|
+import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
|
|
|
import lombok.NonNull;
|
|
|
import org.apache.commons.collections4.CollectionUtils;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
@@ -279,6 +281,12 @@ public class VectorUtils {
|
|
|
builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode());
|
|
|
}
|
|
|
|
|
|
+ // set ranker, support reranking search result from v2.6.1
|
|
|
+ CreateCollectionReq.Function ranker = request.getRanker();
|
|
|
+ if (ranker != null) {
|
|
|
+ builder.setFunctionScore(convertFunctionScore(ranker));
|
|
|
+ }
|
|
|
+
|
|
|
return builder.build();
|
|
|
}
|
|
|
|
|
@@ -473,16 +481,25 @@ public class VectorUtils {
|
|
|
}
|
|
|
|
|
|
// set ranker
|
|
|
- BaseRanker ranker = request.getRanker();
|
|
|
- if (request.getRanker() == null) {
|
|
|
+ CreateCollectionReq.Function ranker = request.getRanker();
|
|
|
+ if (ranker == null) {
|
|
|
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Ranker is null.");
|
|
|
}
|
|
|
|
|
|
- // topK value is deprecated, always use "limit" to set the topK
|
|
|
- Map<String, String> props = ranker.getProperties();
|
|
|
+ Map<String, String> props = new HashMap<>();
|
|
|
props.put(Constant.LIMIT, String.valueOf(request.getLimit()));
|
|
|
props.put(Constant.ROUND_DECIMAL, String.valueOf(request.getRoundDecimal()));
|
|
|
props.put(Constant.OFFSET, String.valueOf(request.getOffset()));
|
|
|
+
|
|
|
+ if (ranker instanceof RRFRanker || ranker instanceof WeightedRanker) {
|
|
|
+ // old logic for RRF/Weighted ranker
|
|
|
+ Map<String, String> params = ranker.getParams();
|
|
|
+ props.putAll(params);
|
|
|
+ } else {
|
|
|
+ // new logic for Decay/Model ranker
|
|
|
+ builder.setFunctionScore(convertFunctionScore(ranker));
|
|
|
+ }
|
|
|
+
|
|
|
List<KeyValuePair> propertiesList = ParamUtils.AssembleKvPair(props);
|
|
|
if (CollectionUtils.isNotEmpty(propertiesList)) {
|
|
|
propertiesList.forEach(builder::addRankParams);
|
|
@@ -528,6 +545,17 @@ public class VectorUtils {
|
|
|
return builder.build();
|
|
|
}
|
|
|
|
|
|
+ private FunctionScore convertFunctionScore(CreateCollectionReq.Function function) {
|
|
|
+ FunctionSchema schema = FunctionSchema.newBuilder()
|
|
|
+ .setName(function.getName())
|
|
|
+ .setDescription(function.getDescription())
|
|
|
+ .setType(FunctionType.forNumber(function.getFunctionType().getCode()))
|
|
|
+ .addAllInputFieldNames(function.getInputFieldNames())
|
|
|
+ .addAllParams(ParamUtils.AssembleKvPair(function.getParams()))
|
|
|
+ .build();
|
|
|
+ return FunctionScore.newBuilder().addFunctions(schema).build();
|
|
|
+ }
|
|
|
+
|
|
|
public String getExprById(String primaryFieldName, List<?> ids) {
|
|
|
StringBuilder sb = new StringBuilder();
|
|
|
sb.append(primaryFieldName).append(" in [");
|