Browse Source

Refine iterator code (#1180)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 7 months ago
parent
commit
89516f1f16

+ 0 - 0
.attach_pid4051998


+ 2 - 1
src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java

@@ -30,7 +30,8 @@ public class IteratorAdapterV2 {
                 .withOffset(queryIteratorReq.getOffset())
                 .withOffset(queryIteratorReq.getOffset())
                 .withLimit(queryIteratorReq.getLimit())
                 .withLimit(queryIteratorReq.getLimit())
                 .withIgnoreGrowing(queryIteratorReq.isIgnoreGrowing())
                 .withIgnoreGrowing(queryIteratorReq.isIgnoreGrowing())
-                .withBatchSize(queryIteratorReq.getBatchSize());
+                .withBatchSize(queryIteratorReq.getBatchSize())
+                .withReduceStopForBest(queryIteratorReq.isReduceStopForBest());
 
 
         if (queryIteratorReq.getConsistencyLevel() != null) {
         if (queryIteratorReq.getConsistencyLevel() != null) {
             builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(queryIteratorReq.getConsistencyLevel().name()));
             builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(queryIteratorReq.getConsistencyLevel().name()));

+ 24 - 15
src/main/java/io/milvus/orm/iterator/QueryIterator.java

@@ -19,10 +19,8 @@
 
 
 package io.milvus.orm.iterator;
 package io.milvus.orm.iterator;
 
 
-import io.milvus.grpc.DataType;
-import io.milvus.grpc.MilvusServiceGrpc;
-import io.milvus.grpc.QueryRequest;
-import io.milvus.grpc.QueryResults;
+import io.milvus.grpc.*;
+import io.milvus.param.Constant;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.collection.FieldType;
 import io.milvus.param.collection.FieldType;
 import io.milvus.param.dml.QueryIteratorParam;
 import io.milvus.param.dml.QueryIteratorParam;
@@ -98,7 +96,7 @@ public class QueryIterator {
     // perform a query to get the first time stamp check point
     // perform a query to get the first time stamp check point
     // the time stamp will be input for the next query to skip something
     // the time stamp will be input for the next query to skip something
     private void setupTsByRequest() {
     private void setupTsByRequest() {
-        QueryResults response = getQueryResultsWrapper(expr, 0L, 1L, 0L);
+        QueryResults response = executeQuery(expr, 0L, 1L, 0L);
         if (response.getSessionTs() <= 0) {
         if (response.getSessionTs() <= 0) {
             logger.warn("Failed to get mvccTs from milvus server, use client-side ts instead");
             logger.warn("Failed to get mvccTs from milvus server, use client-side ts instead");
             // fall back to latest session ts by local time
             // fall back to latest session ts by local time
@@ -116,7 +114,7 @@ public class QueryIterator {
             return;
             return;
         }
         }
 
 
-        QueryResults response = getQueryResultsWrapper(expr, 0L, offset, this.sessionTs);
+        QueryResults response = executeQuery(expr, 0L, offset, this.sessionTs);
         QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
         QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
         List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
         List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
         int resultIndex = Math.min(res.size(), (int) offset);
         int resultIndex = Math.min(res.size(), (int) offset);
@@ -135,7 +133,7 @@ public class QueryIterator {
             iteratorCache.releaseCache(cacheIdInUse);
             iteratorCache.releaseCache(cacheIdInUse);
             String currentExpr = setupNextExpr();
             String currentExpr = setupNextExpr();
             logger.debug("Query iterator next expression: " + currentExpr);
             logger.debug("Query iterator next expression: " + currentExpr);
-            QueryResults response = getQueryResultsWrapper(currentExpr, offset, batchSize, this.sessionTs);
+            QueryResults response = executeQuery(currentExpr, offset, batchSize, this.sessionTs);
             QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
             QueryResultsWrapper queryWrapper = new QueryResultsWrapper(response);
             List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
             List<QueryResultsWrapper.RowRecord> res = queryWrapper.getRowRecords();
             maybeCache(res);
             maybeCache(res);
@@ -199,7 +197,7 @@ public class QueryIterator {
         return ret != null && ret.size() >= batchSize;
         return ret != null && ret.size() >= batchSize;
     }
     }
 
 
-    private QueryResults getQueryResultsWrapper(String expr, long offset, long limit, long ts) {
+    private QueryResults executeQuery(String expr, long offset, long limit, long ts) {
         QueryParam queryParam = QueryParam.newBuilder()
         QueryParam queryParam = QueryParam.newBuilder()
                 .withDatabaseName(queryIteratorParam.getDatabaseName())
                 .withDatabaseName(queryIteratorParam.getDatabaseName())
                 .withCollectionName(queryIteratorParam.getCollectionName())
                 .withCollectionName(queryIteratorParam.getCollectionName())
@@ -210,20 +208,31 @@ public class QueryIterator {
                 .withOffset(offset)
                 .withOffset(offset)
                 .withLimit(limit)
                 .withLimit(limit)
                 .withIgnoreGrowing(queryIteratorParam.isIgnoreGrowing())
                 .withIgnoreGrowing(queryIteratorParam.isIgnoreGrowing())
-                .withReduceStopForBest(queryIteratorParam.isReduceStopForBest())
-                .withIterator(Boolean.TRUE)
                 .build();
                 .build();
 
 
         QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam);
         QueryRequest queryRequest = ParamUtils.convertQueryParam(queryParam);
+        QueryRequest.Builder builder = queryRequest.toBuilder();
+        // reduce stop for best
+        builder.addQueryParams(KeyValuePair.newBuilder()
+                .setKey(Constant.REDUCE_STOP_FOR_BEST)
+                .setValue(String.valueOf(queryIteratorParam.isReduceStopForBest()))
+                .build());
+
+        // iterator
+        builder.addQueryParams(KeyValuePair.newBuilder()
+                .setKey(Constant.ITERATOR_FIELD)
+                .setValue(String.valueOf(Boolean.TRUE))
+                .build());
+
         // pass the session ts to query interface
         // pass the session ts to query interface
-        if (ts > 0) {
-            queryRequest = queryRequest.toBuilder().setGuaranteeTimestamp(ts).build();
-        }
-        QueryResults response = blockingStub.query(queryRequest);
+        builder.setGuaranteeTimestamp(ts).build();
 
 
+        // set default consistency level
+        builder.setUseDefaultConsistency(true);
+
+        QueryResults response = blockingStub.query(builder.build());
         String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName());
         String title = String.format("QueryRequest collectionName:%s", queryIteratorParam.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());
         rpcUtils.handleResponse(title, response.getStatus());
-
         return response;
         return response;
     }
     }
 }
 }

+ 18 - 11
src/main/java/io/milvus/orm/iterator/SearchIterator.java

@@ -8,6 +8,7 @@ import io.milvus.common.utils.ExceptionUtils;
 import io.milvus.common.utils.JsonUtils;
 import io.milvus.common.utils.JsonUtils;
 import io.milvus.exception.ParamException;
 import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
 import io.milvus.grpc.*;
+import io.milvus.param.Constant;
 import io.milvus.param.MetricType;
 import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.collection.FieldType;
 import io.milvus.param.collection.FieldType;
@@ -172,7 +173,7 @@ public class SearchIterator {
     }
     }
 
 
     private void initSearchIterator() {
     private void initSearchIterator() {
-        SearchResults response = executeNextSearch(params, expr, false, 0L);
+        SearchResults response = executeSearch(params, expr, false, 0L);
         if (response.getSessionTs() <= 0) {
         if (response.getSessionTs() <= 0) {
             logger.warn("Failed to get mvccTs from milvus server, use client-side ts instead");
             logger.warn("Failed to get mvccTs from milvus server, use client-side ts instead");
             // fall back to latest session ts by local time
             // fall back to latest session ts by local time
@@ -245,7 +246,7 @@ public class SearchIterator {
         }
         }
     }
     }
 
 
-    private SearchResults executeNextSearch(Map<String, Object> params, String nextExpr, boolean toExtendBatch, long ts) {
+    private SearchResults executeSearch(Map<String, Object> params, String nextExpr, boolean toExtendBatch, long ts) {
         SearchParam.Builder searchParamBuilder = SearchParam.newBuilder()
         SearchParam.Builder searchParamBuilder = SearchParam.newBuilder()
                 .withDatabaseName(searchIteratorParam.getDatabaseName())
                 .withDatabaseName(searchIteratorParam.getDatabaseName())
                 .withCollectionName(searchIteratorParam.getCollectionName())
                 .withCollectionName(searchIteratorParam.getCollectionName())
@@ -258,9 +259,7 @@ public class SearchIterator {
                 .withRoundDecimal(searchIteratorParam.getRoundDecimal())
                 .withRoundDecimal(searchIteratorParam.getRoundDecimal())
                 .withParams(JsonUtils.toJson(params))
                 .withParams(JsonUtils.toJson(params))
                 .withMetricType(MetricType.valueOf(searchIteratorParam.getMetricType()))
                 .withMetricType(MetricType.valueOf(searchIteratorParam.getMetricType()))
-                .withIgnoreGrowing(searchIteratorParam.isIgnoreGrowing())
-                .withIterator(Boolean.TRUE)
-                ;
+                .withIgnoreGrowing(searchIteratorParam.isIgnoreGrowing());
 
 
         if (!StringUtils.isNullOrEmpty(searchIteratorParam.getGroupByFieldName())) {
         if (!StringUtils.isNullOrEmpty(searchIteratorParam.getGroupByFieldName())) {
             searchParamBuilder.withGroupByFieldName(searchIteratorParam.getGroupByFieldName());
             searchParamBuilder.withGroupByFieldName(searchIteratorParam.getGroupByFieldName());
@@ -268,15 +267,23 @@ public class SearchIterator {
         fillVectorsByPlType(searchParamBuilder);
         fillVectorsByPlType(searchParamBuilder);
 
 
         SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParamBuilder.build());
         SearchRequest searchRequest = ParamUtils.convertSearchParam(searchParamBuilder.build());
+        SearchRequest.Builder builder = searchRequest.toBuilder();
+        // iterator
+        builder.addSearchParams(
+                KeyValuePair.newBuilder()
+                        .setKey(Constant.ITERATOR_FIELD)
+                        .setValue(String.valueOf(Boolean.TRUE))
+                        .build());
+
         // pass the session ts to search interface
         // pass the session ts to search interface
-        if (ts > 0) {
-            searchRequest = searchRequest.toBuilder().setGuaranteeTimestamp(ts).build();
-        }
-        SearchResults response = blockingStub.search(searchRequest);
+        builder.setGuaranteeTimestamp(ts).build();
 
 
+        // set default consistency level
+        builder.setUseDefaultConsistency(true);
+
+        SearchResults response = blockingStub.search(builder.build());
         String title = String.format("SearchRequest collectionName:%s", searchIteratorParam.getCollectionName());
         String title = String.format("SearchRequest collectionName:%s", searchIteratorParam.getCollectionName());
         rpcUtils.handleResponse(title, response.getStatus());
         rpcUtils.handleResponse(title, response.getStatus());
-
         return response;
         return response;
     }
     }
 
 
@@ -388,7 +395,7 @@ public class SearchIterator {
         while (true) {
         while (true) {
             Map<String, Object> nextParams = nextParams(coefficient);
             Map<String, Object> nextParams = nextParams(coefficient);
             String nextExpr = filteredDuplicatedResultExpr(expr);
             String nextExpr = filteredDuplicatedResultExpr(expr);
-            SearchResults response = executeNextSearch(nextParams, nextExpr, true, this.sessionTs);
+            SearchResults response = executeSearch(nextParams, nextExpr, true, this.sessionTs);
             SearchResultsWrapper searchResultsWrapper = new SearchResultsWrapper(response.getResults());
             SearchResultsWrapper searchResultsWrapper = new SearchResultsWrapper(response.getResults());
             updateFilteredIds(searchResultsWrapper);
             updateFilteredIds(searchResultsWrapper);
             List<QueryResultsWrapper.RowRecord> newPage = searchResultsWrapper.getRowRecords(0);
             List<QueryResultsWrapper.RowRecord> newPage = searchResultsWrapper.getRowRecords(0);

+ 0 - 23
src/main/java/io/milvus/param/ParamUtils.java

@@ -803,11 +803,6 @@ public class ParamUtils {
                         KeyValuePair.newBuilder()
                         KeyValuePair.newBuilder()
                                 .setKey(Constant.IGNORE_GROWING)
                                 .setKey(Constant.IGNORE_GROWING)
                                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
-                                .build())
-                .addSearchParams(
-                        KeyValuePair.newBuilder()
-                                .setKey(Constant.ITERATOR_FIELD)
-                                .setValue(String.valueOf(requestParam.isIterator()))
                                 .build());
                                 .build());
 
 
         if (!Objects.equals(requestParam.getMetricType(), MetricType.None.name())) {
         if (!Objects.equals(requestParam.getMetricType(), MetricType.None.name())) {
@@ -1003,12 +998,6 @@ public class ParamUtils {
     public static QueryRequest convertQueryParam(@NonNull QueryParam requestParam) {
     public static QueryRequest convertQueryParam(@NonNull QueryParam requestParam) {
         boolean useDefaultConsistency = (requestParam.getConsistencyLevel() == null);
         boolean useDefaultConsistency = (requestParam.getConsistencyLevel() == null);
         long guaranteeTimestamp = getGuaranteeTimestamp(requestParam.getConsistencyLevel(), requestParam.getCollectionName());
         long guaranteeTimestamp = getGuaranteeTimestamp(requestParam.getConsistencyLevel(), requestParam.getCollectionName());
-        // special logic for iterator
-        // don't pass guaranteeTimestamp for iterator, the query() interface might return empty list.
-        if (requestParam.isIterator()) {
-            useDefaultConsistency = true;
-            guaranteeTimestamp = 0L;
-        }
         QueryRequest.Builder builder = QueryRequest.newBuilder()
         QueryRequest.Builder builder = QueryRequest.newBuilder()
                 .setCollectionName(requestParam.getCollectionName())
                 .setCollectionName(requestParam.getCollectionName())
                 .addAllPartitionNames(requestParam.getPartitionNames())
                 .addAllPartitionNames(requestParam.getPartitionNames())
@@ -1052,18 +1041,6 @@ public class ParamUtils {
                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                 .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                 .build());
                 .build());
 
 
-        // reduce stop for best
-        builder.addQueryParams(KeyValuePair.newBuilder()
-                        .setKey(Constant.REDUCE_STOP_FOR_BEST)
-                        .setValue(String.valueOf(requestParam.isReduceStopForBest()))
-                .build());
-
-        // iterator
-        builder.addQueryParams(KeyValuePair.newBuilder()
-                        .setKey(Constant.ITERATOR_FIELD)
-                        .setValue(String.valueOf(requestParam.isIterator()))
-                .build());
-
         return builder.build();
         return builder.build();
     }
     }
 
 

+ 0 - 28
src/main/java/io/milvus/param/dml/QueryParam.java

@@ -49,8 +49,6 @@ public class QueryParam {
     private final long offset;
     private final long offset;
     private final long limit;
     private final long limit;
     private final boolean ignoreGrowing;
     private final boolean ignoreGrowing;
-    private final boolean reduceStopForBest;
-    private final boolean iterator;
 
 
     private QueryParam(@NonNull Builder builder) {
     private QueryParam(@NonNull Builder builder) {
         this.databaseName = builder.databaseName;
         this.databaseName = builder.databaseName;
@@ -65,8 +63,6 @@ public class QueryParam {
         this.offset = builder.offset;
         this.offset = builder.offset;
         this.limit = builder.limit;
         this.limit = builder.limit;
         this.ignoreGrowing = builder.ignoreGrowing;
         this.ignoreGrowing = builder.ignoreGrowing;
-        this.reduceStopForBest = builder.reduceStopForBest;
-        this.iterator = builder.iterator;
     }
     }
 
 
     public static Builder newBuilder() {
     public static Builder newBuilder() {
@@ -89,8 +85,6 @@ public class QueryParam {
         private Long offset = 0L;
         private Long offset = 0L;
         private Long limit = 0L;
         private Long limit = 0L;
         private Boolean ignoreGrowing = Boolean.FALSE;
         private Boolean ignoreGrowing = Boolean.FALSE;
-        private Boolean reduceStopForBest = Boolean.FALSE;
-        private Boolean iterator = Boolean.FALSE;
 
 
         private Builder() {
         private Builder() {
         }
         }
@@ -224,28 +218,6 @@ public class QueryParam {
             return this;
             return this;
         }
         }
 
 
-        /**
-         * Adjust the query using iterators to handle offsets more efficiently during the Reduce step. Default is False.
-         *
-         * @param reduceStopForBest <code>Boolean.TRUE</code> ignore, Boolean.FALSE is not
-         * @return <code>Builder</code>
-         */
-        public Builder withReduceStopForBest(@NonNull Boolean reduceStopForBest) {
-            this.reduceStopForBest = reduceStopForBest;
-            return this;
-        }
-
-        /**
-         * Optimizing specifically for iterators can yield correct data results. Default is False.
-         *
-         * @param iterator <code>Boolean.TRUE</code> ignore, Boolean.FALSE is not
-         * @return <code>Builder</code>
-         */
-        public Builder withIterator(@NonNull Boolean iterator) {
-            this.iterator = iterator;
-            return this;
-        }
-
         /**
         /**
          * Verifies parameters and creates a new {@link QueryParam} instance.
          * Verifies parameters and creates a new {@link QueryParam} instance.
          *
          *

+ 0 - 14
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -59,7 +59,6 @@ public class SearchParam {
     private final Integer groupSize;
     private final Integer groupSize;
     private final Boolean strictGroupSize;
     private final Boolean strictGroupSize;
     private final PlaceholderType plType;
     private final PlaceholderType plType;
-    private final boolean iterator;
 
 
     private SearchParam(@NonNull Builder builder) {
     private SearchParam(@NonNull Builder builder) {
         this.databaseName = builder.databaseName;
         this.databaseName = builder.databaseName;
@@ -83,7 +82,6 @@ public class SearchParam {
         this.groupSize = builder.groupSize;
         this.groupSize = builder.groupSize;
         this.strictGroupSize = builder.strictGroupSize;
         this.strictGroupSize = builder.strictGroupSize;
         this.plType = builder.plType;
         this.plType = builder.plType;
-        this.iterator = builder.iterator;
     }
     }
 
 
     public static Builder newBuilder() {
     public static Builder newBuilder() {
@@ -114,7 +112,6 @@ public class SearchParam {
         private String groupByFieldName;
         private String groupByFieldName;
         private Integer groupSize = null;
         private Integer groupSize = null;
         private Boolean strictGroupSize = null;
         private Boolean strictGroupSize = null;
-        private Boolean iterator = Boolean.FALSE;
 
 
         // plType is used to distinct vector type
         // plType is used to distinct vector type
         // for Float16Vector/BFloat16Vector and BinaryVector, user inputs ByteBuffer
         // for Float16Vector/BFloat16Vector and BinaryVector, user inputs ByteBuffer
@@ -406,17 +403,6 @@ public class SearchParam {
             return this;
             return this;
         }
         }
 
 
-        /**
-         * Optimizing specifically for iterators can yield correct data results. Default is False.
-         *
-         * @param iterator <code>Boolean.TRUE</code> ignore, Boolean.FALSE is not
-         * @return <code>Builder</code>
-         */
-        public Builder withIterator(@NonNull Boolean iterator) {
-            this.iterator = iterator;
-            return this;
-        }
-
         /**
         /**
          * Verifies parameters and creates a new {@link SearchParam} instance.
          * Verifies parameters and creates a new {@link SearchParam} instance.
          *
          *

+ 2 - 0
src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java

@@ -29,4 +29,6 @@ public class QueryIteratorReq {
     private boolean ignoreGrowing = false;
     private boolean ignoreGrowing = false;
     @Builder.Default
     @Builder.Default
     private long batchSize = 1000L;
     private long batchSize = 1000L;
+    @Builder.Default
+    private boolean reduceStopForBest = false;
 }
 }