Răsfoiți Sursa

Support expression template values (#1150)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
Co-authored-by: cai.zhang <cai.zhang@zilliz.com>
groot 8 luni în urmă
părinte
comite
cf92520494

+ 10 - 4
src/main/java/io/milvus/v2/service/vector/VectorService.java

@@ -41,6 +41,7 @@ import org.slf4j.LoggerFactory;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentHashMap;
 
 
 public class VectorService extends BaseService {
 public class VectorService extends BaseService {
@@ -224,12 +225,17 @@ public class VectorService extends BaseService {
         if (request.getFilter() == null) {
         if (request.getFilter() == null) {
             request.setFilter(vectorUtils.getExprById(respR.getPrimaryFieldName(), request.getIds()));
             request.setFilter(vectorUtils.getExprById(respR.getPrimaryFieldName(), request.getIds()));
         }
         }
-        DeleteRequest deleteRequest = DeleteRequest.newBuilder()
+        DeleteRequest.Builder builder = DeleteRequest.newBuilder()
                 .setCollectionName(request.getCollectionName())
                 .setCollectionName(request.getCollectionName())
                 .setPartitionName(request.getPartitionName())
                 .setPartitionName(request.getPartitionName())
-                .setExpr(request.getFilter())
-                .build();
-        MutationResult response = blockingStub.delete(deleteRequest);
+                .setExpr(request.getFilter());
+        if (request.getFilter() != null && !request.getFilter().isEmpty()) {
+            Map<String, Object> filterTemplateValues = request.getFilterTemplateValues();
+            filterTemplateValues.forEach((key, value)->{
+                builder.putExprTemplateValues(key, vectorUtils.deduceAndCreateTemplateValue(value));
+            });
+        }
+        MutationResult response = blockingStub.delete(builder.build());
         rpcUtils.handleResponse(title, response.getStatus());
         rpcUtils.handleResponse(title, response.getStatus());
         GTsDict.getInstance().updateCollectionTs(request.getCollectionName(), response.getTimestamp());
         GTsDict.getInstance().updateCollectionTs(request.getCollectionName(), response.getTimestamp());
         return DeleteResp.builder()
         return DeleteResp.builder()

+ 13 - 0
src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java

@@ -23,7 +23,9 @@ import lombok.Builder;
 import lombok.Data;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
 import lombok.experimental.SuperBuilder;
 
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.List;
+import java.util.Map;
 
 
 @Data
 @Data
 @SuperBuilder
 @SuperBuilder
@@ -33,4 +35,15 @@ public class DeleteReq {
     private String partitionName = "";
     private String partitionName = "";
     private String filter;
     private String filter;
     private List<Object> ids;
     private List<Object> ids;
+
+    // Expression template, to improve expression parsing performance in complicated list
+    // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
+    // The long list of city will increase the time cost to parse this expression.
+    // So, we provide exprTemplateValues for this purpose, user can set filter like this:
+    //     filter = "pk > {age} and city in {city}"
+    //     filterTemplateValues = Map{"age": 3, "city": List<String>{"beijing", "shanghai", ......}}
+    // Valid value of this map can be:
+    //     Boolean, Long, Double, String, List<Boolean>, List<Long>, List<Double>, List<String>
+    @Builder.Default
+    private Map<String, Object> filterTemplateValues = new HashMap<>();
 }
 }

+ 12 - 3
src/main/java/io/milvus/v2/service/vector/request/QueryReq.java

@@ -24,9 +24,7 @@ import lombok.Builder;
 import lombok.Data;
 import lombok.Data;
 import lombok.experimental.SuperBuilder;
 import lombok.experimental.SuperBuilder;
 
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
+import java.util.*;
 
 
 @Data
 @Data
 @SuperBuilder
 @SuperBuilder
@@ -42,4 +40,15 @@ public class QueryReq {
     private ConsistencyLevel consistencyLevel = null;
     private ConsistencyLevel consistencyLevel = null;
     private long offset;
     private long offset;
     private long limit;
     private long limit;
+
+    // Expression template, to improve expression parsing performance in complicated list
+    // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
+    // The long list of city will increase the time cost to parse this expression.
+    // So, we provide exprTemplateValues for this purpose, user can set filter like this:
+    //     filter = "pk > {age} and city in {city}"
+    //     filterTemplateValues = Map{"age": 3, "city": List<String>{"beijing", "shanghai", ......}}
+    // Valid value of this map can be:
+    //     Boolean, Long, Double, String, List<Boolean>, List<Long>, List<Double>, List<String>
+    @Builder.Default
+    private Map<String, Object> filterTemplateValues = new HashMap<>();
 }
 }

+ 11 - 0
src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

@@ -57,4 +57,15 @@ public class SearchReq {
     private ConsistencyLevel consistencyLevel = null;
     private ConsistencyLevel consistencyLevel = null;
     private boolean ignoreGrowing;
     private boolean ignoreGrowing;
     private String groupByFieldName;
     private String groupByFieldName;
+
+    // Expression template, to improve expression parsing performance in complicated list
+    // Assume user has a filter = "pk > 3 and city in ["beijing", "shanghai", ......]
+    // The long list of city will increase the time cost to parse this expression.
+    // So, we provide exprTemplateValues for this purpose, user can set filter like this:
+    //     filter = "pk > {age} and city in {city}"
+    //     filterTemplateValues = Map{"age": 3, "city": List<String>{"beijing", "shanghai", ......}}
+    // Valid value of this map can be:
+    //     Boolean, Long, Double, String, List<Boolean>, List<Long>, List<Double>, List<String>
+    @Builder.Default
+    private Map<String, Object> filterTemplateValues = new HashMap<>();
 }
 }

+ 61 - 0
src/main/java/io/milvus/v2/utils/VectorUtils.java

@@ -44,6 +44,12 @@ public class VectorUtils {
                 .addAllPartitionNames(request.getPartitionNames())
                 .addAllPartitionNames(request.getPartitionNames())
                 .addAllOutputFields(request.getOutputFields())
                 .addAllOutputFields(request.getOutputFields())
                 .setExpr(request.getFilter());
                 .setExpr(request.getFilter());
+        if (request.getFilter() != null && !request.getFilter().isEmpty()) {
+            Map<String, Object> filterTemplateValues = request.getFilterTemplateValues();
+            filterTemplateValues.forEach((key, value)->{
+                builder.putExprTemplateValues(key, deduceAndCreateTemplateValue(value));
+            });
+        }
 
 
         // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true
         // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true
         if (request.getConsistencyLevel() == null) {
         if (request.getConsistencyLevel() == null) {
@@ -180,6 +186,10 @@ public class VectorUtils {
         builder.setDslType(DslType.BoolExprV1);
         builder.setDslType(DslType.BoolExprV1);
         if (request.getFilter() != null && !request.getFilter().isEmpty()) {
         if (request.getFilter() != null && !request.getFilter().isEmpty()) {
             builder.setDsl(request.getFilter());
             builder.setDsl(request.getFilter());
+            Map<String, Object> filterTemplateValues = request.getFilterTemplateValues();
+            filterTemplateValues.forEach((key, value)->{
+                builder.putExprTemplateValues(key, deduceAndCreateTemplateValue(value));
+            });
         }
         }
 
 
         long guaranteeTimestamp = getGuaranteeTimestamp(request.getConsistencyLevel(), request.getCollectionName());
         long guaranteeTimestamp = getGuaranteeTimestamp(request.getConsistencyLevel(), request.getCollectionName());
@@ -195,6 +205,57 @@ public class VectorUtils {
         return builder.build();
         return builder.build();
     }
     }
 
 
+    public static TemplateValue deduceAndCreateTemplateValue(Object value) {
+        if (value instanceof Boolean) {
+            return TemplateValue.newBuilder()
+                    .setBoolVal((Boolean)value)
+                    .setType(DataType.Bool)
+                    .build();
+        } else if (value instanceof Integer) {
+           return TemplateValue.newBuilder()
+                   .setInt64Val((Integer)value)
+                   .setType(DataType.Int64)
+                   .build();
+        } else if (value instanceof Double) {
+            return TemplateValue.newBuilder()
+                    .setFloatVal((Double)value)
+                    .setType(DataType.Double)
+                    .build();
+        } else if (value instanceof String) {
+            return TemplateValue.newBuilder()
+                    .setStringVal((String)value)
+                    .setType(DataType.VarChar)
+                    .build();
+        } else if (value instanceof List) {
+            // TemplateArrayValue and TemplateValue can nest each other
+            // The element_type of TemplateArrayValue is deduced by its elements:
+            //   1. if all the elements are the same type, element_type is the first element's type
+            //   2. if not the same type, element_type is DataType.JSON
+            List<Object> array = (List<Object>)value;
+            TemplateArrayValue.Builder builder = TemplateArrayValue.newBuilder();
+            DataType lastType = DataType.UNRECOGNIZED;
+            boolean sameType = true;
+            for (Object obj : array) {
+                TemplateValue tv = deduceAndCreateTemplateValue(obj);
+                builder.addArray(tv);
+                if (sameType && lastType != DataType.UNRECOGNIZED && lastType != tv.getType()) {
+                    sameType = false;
+                }
+                lastType = tv.getType();
+            }
+            DataType arrayType = sameType ? lastType : DataType.JSON;
+            builder.setElementType(arrayType);
+            builder.setSameType(sameType);
+
+            return TemplateValue.newBuilder()
+                    .setArrayVal(builder.build())
+                    .setType(builder.getElementType())
+                    .build();
+        } else {
+            throw new ParamException("Unsupported value type for expression template.");
+        }
+    }
+
     public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchReq annSearchReq,
     public static SearchRequest convertAnnSearchParam(@NonNull AnnSearchReq annSearchReq,
                                                       ConsistencyLevel consistencyLevel) {
                                                       ConsistencyLevel consistencyLevel) {
         SearchRequest.Builder builder = SearchRequest.newBuilder();
         SearchRequest.Builder builder = SearchRequest.newBuilder();

+ 1 - 1
src/main/milvus-proto

@@ -1 +1 @@
-Subproject commit ef9b8fd69497e5dc0c746057436b6411ede6f912
+Subproject commit 4d5c88b00cf7280b17542940d3b49041f1c92f66

+ 32 - 3
src/test/java/io/milvus/v2/service/vector/VectorTest.java

@@ -29,9 +29,7 @@ import org.junit.jupiter.api.Test;
 import org.slf4j.Logger;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.slf4j.LoggerFactory;
 
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
+import java.util.*;
 
 
 class VectorTest extends BaseTest {
 class VectorTest extends BaseTest {
 
 
@@ -105,6 +103,37 @@ class VectorTest extends BaseTest {
         logger.info(statusR.toString());
         logger.info(statusR.toString());
     }
     }
 
 
+    @Test
+    void testSearchWithTemplateExpression() {
+        List<Float> vectorList = new ArrayList<>();
+        vectorList.add(1.0f);
+        vectorList.add(2.0f);
+
+        Map<String, Map<String, Object>> expressionTemplateValues = new HashMap<>();
+        Map<String, Object> params = new HashMap<>();
+        params.put("max", 10);
+        expressionTemplateValues.put("id < {max}", params);
+
+        List<Object> list = Arrays.asList(1, 2, 3);
+        Map<String, Object> params2 = new HashMap<>();
+        params2.put("list", list);
+        expressionTemplateValues.put("id in {list}", params2);
+
+        expressionTemplateValues.forEach((key, value) -> {
+            SearchReq request = SearchReq.builder()
+                    .collectionName("test")
+                    .data(Collections.singletonList(new FloatVec(vectorList)))
+                    .topK(10)
+                    .offset(0L)
+                    .filter(key)
+                    .filterTemplateValues(value)
+                    .build();
+            SearchResp statusR = client_v2.search(request);
+            logger.info(statusR.toString());
+            System.out.println(statusR.toString());
+        });
+    }
+
     @Test
     @Test
     void testDelete() {
     void testDelete() {
         DeleteReq request = DeleteReq.builder()
         DeleteReq request = DeleteReq.builder()