瀏覽代碼

Fix QueryIterator bug where the offset could not exceed 16384 (#1582)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 月之前
父節點
當前提交
11521fa9b1

+ 24 - 9
sdk-core/src/main/java/io/milvus/orm/iterator/QueryIterator.java

@@ -33,9 +33,11 @@ import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
 import java.util.List;
 
 import static io.milvus.param.Constant.NO_CACHE_ID;
+import static io.milvus.param.Constant.MAX_BATCH_SIZE;
 import static io.milvus.param.Constant.UNLIMITED;
 
 public class QueryIterator {
@@ -96,7 +98,7 @@ public class QueryIterator {
     // perform a query to get the first time stamp check point
     // the time stamp will be input for the next query to skip something
     private void setupTsByRequest() {
-        QueryResults response = executeQuery(expr, 0L, 1L, 0L);
+        QueryResults response = executeQuery(expr, 0L, 1L, 0L, true);
         if (response.getSessionTs() <= 0) {
             logger.warn("Failed to get mvccTs from milvus server, use client-side ts instead");
             // fall back to latest session ts by local time
@@ -114,11 +116,19 @@ public class QueryIterator {
             return;
         }
 
-        QueryResults response = executeQuery(expr, 0L, offset, this.sessionTs);
-        QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
-        List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
-        int resultIndex = Math.min(res.size(), (int) offset);
-        updateCursor(res.subList(0, resultIndex));
+        long currentOffset = offset;
+        while (currentOffset > 0) {
+            long limit = Math.min(MAX_BATCH_SIZE, currentOffset);
+            String currentExpr = setupNextExpr();
+            QueryResults response = executeQuery(currentExpr, 0L, limit, this.sessionTs, true);
+            QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
+            List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
+            if (res.isEmpty()) {
+                break;
+            }
+            updateCursor(res);
+            currentOffset -= res.size();
+        }
         offset = 0;
     }
 
@@ -133,7 +143,7 @@ public class QueryIterator {
             iteratorCache.releaseCache(cacheIdInUse);
             String currentExpr = setupNextExpr();
             logger.debug("Query iterator next expression: " + currentExpr);
-            QueryResults response = executeQuery(currentExpr, offset, batchSize, this.sessionTs);
+            QueryResults response = executeQuery(currentExpr, offset, batchSize, this.sessionTs, false);
             QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
             List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
             maybeCache(res);
@@ -197,13 +207,18 @@ public class QueryIterator {
         return ret != null && ret.size() >= batchSize;
     }
 
-    private QueryResults executeQuery(String expr, long offset, long limit, long ts) {
+    private QueryResults executeQuery(String expr, long offset, long limit, long ts, boolean isSeek) {
+        // for seeking offset, no need to return output fields
+        List<String> outputFields = new ArrayList<>();
+        if (!isSeek) {
+            outputFields = queryIteratorParam.getOutFields();
+        }
         QueryParam queryParam = QueryParam.newBuilder()
                 .withDatabaseName(queryIteratorParam.getDatabaseName())
                 .withCollectionName(queryIteratorParam.getCollectionName())
                 .withConsistencyLevel(queryIteratorParam.getConsistencyLevel())
                 .withPartitionNames(queryIteratorParam.getPartitionNames())
-                .withOutFields(queryIteratorParam.getOutFields())
+                .withOutFields(outputFields)
                 .withExpr(expr)
                 .withOffset(offset)
                 .withLimit(limit)

+ 13 - 9
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1761,7 +1761,7 @@ class MilvusClientV2DockerTest {
         client.createCollection(requestCreate);
 
         // insert rows
-        long count = 10000;
+        long count = 20000;
         List<JsonObject> data = generateRandomData(collectionSchema, count);
         InsertResp insertResp = client.insert(InsertReq.builder()
                 .collectionName(randomCollectionName)
@@ -1847,13 +1847,15 @@ class MilvusClientV2DockerTest {
         Assertions.assertTrue(counter > 0);
 
         // query iterator
+        long from = 17777;
+        long to = 18000;
         QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
                 .collectionName(randomCollectionName)
-                .expr("int64_field < 300")
+                .expr("int64_field < " + String.valueOf(to))
                 .outputFields(Lists.newArrayList("*"))
                 .batchSize(50L)
-                .offset(5)
-                .limit(400)
+                .offset(from)
+                .limit(4000)
                 .consistencyLevel(ConsistencyLevel.EVENTUALLY)
                 .build());
 
@@ -1867,6 +1869,7 @@ class MilvusClientV2DockerTest {
             }
 
             for (QueryResultsWrapper.RowRecord record : res) {
+                Assertions.assertInstanceOf(Long.class, record.get("id"));
                 Assertions.assertInstanceOf(Boolean.class, record.get("bool_field"));
                 Assertions.assertInstanceOf(Integer.class, record.get("int8_field"));
                 Assertions.assertInstanceOf(Integer.class, record.get("int16_field"));
@@ -1882,8 +1885,9 @@ class MilvusClientV2DockerTest {
                 Assertions.assertInstanceOf(ByteBuffer.class, record.get("bfloat16_vector"));
                 Assertions.assertInstanceOf(SortedMap.class, record.get("sparse_vector"));
 
-                long int64Val = (long)record.get("int64_field");
-                Assertions.assertTrue(int64Val < 300L);
+                long int64Val = (long)record.get("id");
+                Assertions.assertTrue(int64Val >= from);
+                Assertions.assertTrue(int64Val < to);
 
                 String varcharVal = (String)record.get("varchar_field");
                 Assertions.assertTrue(varcharVal.startsWith("varchar_"));
@@ -1904,12 +1908,12 @@ class MilvusClientV2DockerTest {
                 Assertions.assertEquals(DIMENSION*2, bfloat16Vector.limit());
 
                 SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)record.get("sparse_vector");
-                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() < 20); // defined in generateSparseVector()
 
                 counter++;
             }
         }
-        Assertions.assertEquals(295, counter);
+        Assertions.assertEquals(to - from, counter);
 
         // search iterator V2
         SearchIteratorV2 searchIteratorV2 = client.searchIteratorV2(SearchIteratorReqV2.builder()
@@ -1969,7 +1973,7 @@ class MilvusClientV2DockerTest {
                 Assertions.assertEquals(DIMENSION*2, bfloat16Vector.limit());
 
                 SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)entity.get("sparse_vector");
-                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() < 20); // defined in generateSparseVector()
 
                 counter++;
             }