Browse Source

External function example for SearchIteratorV2 (#1330)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 month ago
parent
commit
fc13981bc7

+ 19 - 3
examples/src/main/java/io/milvus/v2/IteratorExample.java

@@ -43,6 +43,7 @@ import io.milvus.v2.service.vector.response.SearchResp;
 import org.apache.commons.lang3.StringUtils;
 
 import java.util.*;
+import java.util.function.Function;
 
 public class IteratorExample {
     private static final MilvusClientV2 client;
@@ -196,7 +197,8 @@ public class IteratorExample {
     // Search iterator V2
     // In SDK v2.5.6, we provide a new search iterator implementation. SearchIteratorV2 is recommended.
     // SearchIteratorV2 is faster than V1 by 20~30 percent, and the recall is a little better than V1.
-    private static void searchIteratorV2(String filter, Map<String, Object> params, int batchSize, int topK) {
+    private static void searchIteratorV2(String filter, Map<String, Object> params, int batchSize, int topK,
+                                         Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc) {
         System.out.println("\n========== searchIteratorV2() ==========");
         System.out.println(String.format("expr='%s', params='%s', batchSize=%d, topK=%d",
                 filter, params==null ? "" : params.toString(), batchSize, topK));
@@ -211,6 +213,7 @@ public class IteratorExample {
                 .topK(topK)
                 .metricType(IndexParam.MetricType.L2)
                 .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .externalFilterFunc(externalFilterFunc)
                 .build());
 
         System.out.println("SearchIteratorV2 results:");
@@ -236,10 +239,23 @@ public class IteratorExample {
         queryIterator("userID < 300",50, 5,400);
         searchIteratorV1("userAge > 50 &&userAge < 100", "{\"range_filter\": 15.0, \"radius\": 20.0}", 100, 500);
         searchIteratorV1("", "", 10, 99);
-        searchIteratorV2("userAge > 10 &&userAge < 20", null, 50, 100);
+        searchIteratorV2("userAge > 10 &&userAge < 20", null, 50, 120, null);
 
         Map<String,Object> extraParams = new HashMap<>();
         extraParams.put("radius",15.0);
-        searchIteratorV2("", extraParams, 50, 100);
+        searchIteratorV2("", extraParams, 50, 100, null);
+
+        // use external function to filter the result
+        Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc = (List<SearchResp.SearchResult> src)->{
+            List<SearchResp.SearchResult> newRes = new ArrayList<>();
+            for (SearchResp.SearchResult res : src) {
+                long id = (long)res.getId();
+                if (id%2 == 0) {
+                    newRes.add(res);
+                }
+            }
+            return newRes;
+        };
+        searchIteratorV2("userAge < 20", null, 50, 88, externalFilterFunc);
     }
 }

+ 4 - 2
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1570,12 +1570,14 @@ class MilvusClientV2DockerTest {
                 Assertions.assertEquals(DIMENSION*2, bfloat16Vector.limit());
 
                 SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)entity.get("sparse_vector");
-                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() < 20); // defined in generateSparseVector()
 
                 counter++;
             }
         }
-        Assertions.assertEquals((int)count - 50, counter);
+        // search iterator could not ensure that all the entities can be retrieved
+        // expect count is 9950, but sometimes it returns 9949 or 9948
+        Assertions.assertTrue(counter > ((int)count - 55) && counter <= ((int)count - 50));
 
         client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
     }