Переглянути джерело

Iterator for V2 (#945)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 10 місяців тому
батько
коміт
47739f5961
22 змінених файлів з 713 додано та 30 видалено
  1. 1 0
      .gitignore
  2. 1 1
      examples/main/java/io/milvus/v1/BinaryVectorExample.java
  3. 1 1
      examples/main/java/io/milvus/v1/Float16VectorExample.java
  4. 1 1
      examples/main/java/io/milvus/v1/IteratorExample.java
  5. 1 1
      examples/main/java/io/milvus/v1/SparseVectorExample.java
  6. 1 1
      examples/main/java/io/milvus/v2/Float16VectorExample.java
  7. 161 0
      examples/main/java/io/milvus/v2/IteratorExample.java
  8. 1 3
      examples/main/java/io/milvus/v2/SimpleExample.java
  9. 1 1
      src/main/java/io/milvus/common/utils/Float16Utils.java
  10. 130 0
      src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java
  11. 21 0
      src/main/java/io/milvus/orm/iterator/QueryIterator.java
  12. 27 9
      src/main/java/io/milvus/orm/iterator/SearchIterator.java
  13. 10 0
      src/main/java/io/milvus/response/FieldDataWrapper.java
  14. 3 3
      src/main/java/io/milvus/response/QueryResultsWrapper.java
  15. 3 3
      src/main/java/io/milvus/response/SearchResultsWrapper.java
  16. 15 2
      src/main/java/io/milvus/response/basic/RowRecordWrapper.java
  17. 23 0
      src/main/java/io/milvus/v2/client/MilvusClientV2.java
  18. 18 0
      src/main/java/io/milvus/v2/service/vector/VectorService.java
  19. 32 0
      src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java
  20. 43 0
      src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReq.java
  21. 6 3
      src/test/java/io/milvus/client/MilvusClientDockerTest.java
  22. 213 1
      src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

+ 1 - 0
.gitignore

@@ -31,6 +31,7 @@ volumes/
 *.iml
 
 # Example files
+examples/bulk_writer
 examples/main/resources/tls/*
 !examples/main/resources/tls/gen.sh
 !examples/main/resources/tls/openssl.cnf

+ 1 - 1
examples/main/java/io/milvus/v1/BinaryVectorExample.java

@@ -37,7 +37,7 @@ import java.nio.ByteBuffer;
 import java.util.*;
 
 public class BinaryVectorExample {
-    private static final String COLLECTION_NAME = "java_sdk_example_sparse";
+    private static final String COLLECTION_NAME = "java_sdk_example_binary_vector";
     private static final String ID_FIELD = "id";
     private static final String VECTOR_FIELD = "vector";
 

+ 1 - 1
examples/main/java/io/milvus/v1/Float16VectorExample.java

@@ -36,7 +36,7 @@ import java.util.*;
 
 
 public class Float16VectorExample {
-    private static final String COLLECTION_NAME = "java_sdk_example_float16";
+    private static final String COLLECTION_NAME = "java_sdk_example_float16_vector";
     private static final String ID_FIELD = "id";
     private static final String VECTOR_FIELD = "vector";
     private static final Integer VECTOR_DIM = 128;

+ 1 - 1
examples/main/java/io/milvus/v1/IteratorExample.java

@@ -61,7 +61,7 @@ public class IteratorExample {
         milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
     }
 
-    private static final String COLLECTION_NAME = "test_iterator";
+    private static final String COLLECTION_NAME = "java_sdk_example_iterator";
     private static final String ID_FIELD = "userID";
     private static final String VECTOR_FIELD = "userFace";
     private static final Integer VECTOR_DIM = 8;

+ 1 - 1
examples/main/java/io/milvus/v1/SparseVectorExample.java

@@ -37,7 +37,7 @@ import java.util.*;
 
 
 public class SparseVectorExample {
-    private static final String COLLECTION_NAME = "java_sdk_example_sparse";
+    private static final String COLLECTION_NAME = "java_sdk_example_sparse_vector";
     private static final String ID_FIELD = "id";
     private static final String VECTOR_FIELD = "vector";
 

+ 1 - 1
examples/main/java/io/milvus/v2/Float16VectorExample.java

@@ -27,7 +27,7 @@ import java.util.*;
 
 
 public class Float16VectorExample {
-    private static final String COLLECTION_NAME = "java_sdk_example_float16";
+    private static final String COLLECTION_NAME = "java_sdk_example_float16_vector";
     private static final String ID_FIELD = "id";
     private static final String FP16_VECTOR_FIELD = "fp16_vector";
     private static final String BF16_VECTOR_FIELD = "bf16_vector";

+ 161 - 0
examples/main/java/io/milvus/v2/IteratorExample.java

@@ -0,0 +1,161 @@
+package io.milvus.v2;
+
+import com.google.common.collect.Lists;
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.orm.iterator.QueryIterator;
+import io.milvus.orm.iterator.SearchIterator;
+import io.milvus.response.QueryResultsWrapper;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryIteratorReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchIteratorReq;
+import io.milvus.v2.service.vector.request.data.FloatVec;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+
+import java.util.*;
+
+public class IteratorExample {
+    private static final String COLLECTION_NAME = "java_sdk_example_iterator";
+    private static final String ID_FIELD = "userID";
+    private static final String AGE_FIELD = "userAge";
+    private static final String VECTOR_FIELD = "userFace";
+    private static final Integer VECTOR_DIM = 128;
+
+    public static void main(String[] args) {
+        ConnectConfig config = ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build();
+        MilvusClientV2 client = new MilvusClientV2(config);
+
+        // create collection
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .autoID(Boolean.FALSE)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(AGE_FIELD)
+                .dataType(DataType.Int32)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(VECTOR_FIELD)
+                .dataType(DataType.FloatVector)
+                .dimension(VECTOR_DIM)
+                .build());
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName(VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.L2)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexParams)
+                .build();
+        client.createCollection(requestCreate);
+
+        // insert rows
+        long count = 10000;
+        List<JsonObject> rowsData = new ArrayList<>();
+        Random ran = new Random();
+        Gson gson = new Gson();
+        for (long i = 0L; i < count; ++i) {
+            JsonObject row = new JsonObject();
+            row.addProperty(ID_FIELD, i);
+            row.addProperty(AGE_FIELD, ran.nextInt(99));
+            row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
+
+            rowsData.add(row);
+        }
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rowsData)
+                .build());
+
+        // check row count
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .filter("")
+                .outputFields(Collections.singletonList("count(*)"))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+        System.out.printf("Inserted row count: %d\n", queryResults.size());
+
+        // search iterator
+        SearchIterator searchIterator = client.searchIterator(SearchIteratorReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .outputFields(Lists.newArrayList(AGE_FIELD))
+                .batchSize(50L)
+                .vectorFieldName(VECTOR_FIELD)
+                .vectors(Collections.singletonList(new FloatVec(CommonUtils.generateFloatVector(VECTOR_DIM))))
+                .expr(String.format("%s > 50 && %s < 100", AGE_FIELD, AGE_FIELD))
+                .params("{\"range_filter\": 15.0, \"radius\": 20.0}")
+                .topK(300)
+                .metricType(IndexParam.MetricType.L2)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+
+        int counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = searchIterator.next();
+            if (res.isEmpty()) {
+                System.out.println("Search iteration finished, close");
+                searchIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                System.out.println(record);
+                counter++;
+            }
+        }
+        System.out.println(String.format("%d search results returned\n", counter));
+
+        // query iterator
+        QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .expr(String.format("%s < 300", ID_FIELD))
+                .outputFields(Lists.newArrayList(ID_FIELD, AGE_FIELD))
+                .batchSize(50L)
+                .offset(5)
+                .limit(400)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+
+        counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = queryIterator.next();
+            if (res.isEmpty()) {
+                System.out.println("query iteration finished, close");
+                queryIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                System.out.println(record);
+                counter++;
+            }
+        }
+        System.out.println(String.format("%d query results returned", counter));
+
+        client.dropCollection(DropCollectionReq.builder().collectionName(COLLECTION_NAME).build());
+    }
+}

+ 1 - 3
examples/main/java/io/milvus/v2/SimpleExample.java

@@ -5,8 +5,6 @@ import io.milvus.v2.client.*;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.collection.request.DropCollectionReq;
-import io.milvus.v2.service.collection.request.GetCollectionStatsReq;
-import io.milvus.v2.service.collection.response.GetCollectionStatsResp;
 import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.request.data.FloatVec;
 import io.milvus.v2.service.vector.response.*;
@@ -21,7 +19,7 @@ public class SimpleExample {
                 .build();
         MilvusClientV2 client = new MilvusClientV2(config);
 
-        String collectionName = "simple_test";
+        String collectionName = "java_sdk_example_simple";
         // drop collection if exists
         client.dropCollection(DropCollectionReq.builder()
                 .collectionName(collectionName)

+ 1 - 1
src/main/java/io/milvus/common/utils/Float16Utils.java

@@ -203,7 +203,7 @@ public class Float16Utils {
     /**
      * Converts a ByteBuffer to a fp16/bf16 vector stored in short array.
      */
-    public static List<Short> BufferToF16Vector(ByteBuffer buf) {
+    public static List<Short> bufferToF16Vector(ByteBuffer buf) {
         buf.rewind(); // reset the read position
         List<Short> vector = new ArrayList<>();
         ShortBuffer sbuf = buf.asShortBuffer();

+ 130 - 0
src/main/java/io/milvus/orm/iterator/IteratorAdapterV2.java

@@ -0,0 +1,130 @@
+package io.milvus.orm.iterator;
+
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.exception.ParamException;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.PlaceholderType;
+import io.milvus.param.MetricType;
+import io.milvus.param.collection.FieldType;
+import io.milvus.param.dml.SearchIteratorParam;
+import io.milvus.param.dml.QueryIteratorParam;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.vector.request.QueryIteratorReq;
+import io.milvus.v2.service.vector.request.SearchIteratorReq;
+import io.milvus.v2.service.vector.request.data.BaseVector;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.SortedMap;
+
+/**
+ * Adapter that translates V2 iterator requests and field schemas into the V1 parameter
+ * objects ({@link QueryIteratorParam}, {@link SearchIteratorParam}, {@link FieldType}).
+ * All methods are static.
+ */
+public class IteratorAdapterV2 {
+    /**
+     * Converts a V2 query-iterator request into a {@link QueryIteratorParam}.
+     * The consistency level is only forwarded when the request carries one.
+     *
+     * @param queryIteratorReq the V2 request
+     * @return the equivalent V1 parameter object
+     */
+    public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq) {
+        QueryIteratorParam.Builder builder = QueryIteratorParam.newBuilder()
+                .withDatabaseName(queryIteratorReq.getDatabaseName())
+                .withCollectionName(queryIteratorReq.getCollectionName())
+                .withPartitionNames(queryIteratorReq.getPartitionNames())
+                .withExpr(queryIteratorReq.getExpr())
+                .withOutFields(queryIteratorReq.getOutputFields())
+                .withOffset(queryIteratorReq.getOffset())
+                .withLimit(queryIteratorReq.getLimit())
+                .withIgnoreGrowing(queryIteratorReq.isIgnoreGrowing())
+                .withBatchSize(queryIteratorReq.getBatchSize());
+
+        // same guard as the search conversion: a null level must not cause a NullPointerException
+        if (queryIteratorReq.getConsistencyLevel() != null) {
+            builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(queryIteratorReq.getConsistencyLevel().name()));
+        }
+
+        return builder.build();
+    }
+
+    /**
+     * Converts a V2 search-iterator request into a {@link SearchIteratorParam}.
+     *
+     * @param searchIteratorReq the V2 request
+     * @return the equivalent V1 parameter object
+     * @throws ParamException if no target vector is given, the target vectors are of
+     *                        different types, or the vector type is not supported
+     */
+    @SuppressWarnings("unchecked") // the payload type of each BaseVector is determined by its PlaceholderType
+    public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) {
+        MetricType metricType = MetricType.None;
+        IndexParam.MetricType requestedMetric = searchIteratorReq.getMetricType();
+        if (requestedMetric != null && requestedMetric != IndexParam.MetricType.INVALID) {
+            metricType = MetricType.valueOf(requestedMetric.name());
+        }
+
+        SearchIteratorParam.Builder builder = SearchIteratorParam.newBuilder()
+                .withDatabaseName(searchIteratorReq.getDatabaseName())
+                .withCollectionName(searchIteratorReq.getCollectionName())
+                .withPartitionNames(searchIteratorReq.getPartitionNames())
+                .withVectorFieldName(searchIteratorReq.getVectorFieldName())
+                .withMetricType(metricType)
+                .withTopK(searchIteratorReq.getTopK())
+                .withExpr(searchIteratorReq.getExpr())
+                .withOutFields(searchIteratorReq.getOutputFields())
+                .withRoundDecimal(searchIteratorReq.getRoundDecimal())
+                .withParams(searchIteratorReq.getParams())
+                .withGroupByFieldName(searchIteratorReq.getGroupByFieldName())
+                .withIgnoreGrowing(searchIteratorReq.isIgnoreGrowing())
+                .withBatchSize(searchIteratorReq.getBatchSize());
+
+        if (searchIteratorReq.getConsistencyLevel() != null) {
+            builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(searchIteratorReq.getConsistencyLevel().name()));
+        }
+
+        List<BaseVector> vectors = searchIteratorReq.getVectors();
+        // the request defaults to an empty list; report that as a parameter error instead of
+        // letting vectors.get(0) fail with an IndexOutOfBoundsException
+        if (vectors == null || vectors.isEmpty()) {
+            throw new ParamException("Target vectors can not be empty.");
+        }
+        PlaceholderType plType = vectors.get(0).getPlaceholderType();
+        for (BaseVector vector : vectors) {
+            if (vector.getPlaceholderType() != plType) {
+                throw new ParamException("Different types of target vectors in a search request is not allowed.");
+            }
+        }
+
+        switch (plType) {
+            case FloatVector: {
+                List<List<Float>> data = new ArrayList<>();
+                vectors.forEach(vector -> data.add((List<Float>) vector.getData()));
+                builder.withFloatVectors(data);
+                break;
+            }
+            case BinaryVector: {
+                List<ByteBuffer> data = new ArrayList<>();
+                vectors.forEach(vector -> data.add((ByteBuffer) vector.getData()));
+                builder.withBinaryVectors(data);
+                break;
+            }
+            case Float16Vector: {
+                List<ByteBuffer> data = new ArrayList<>();
+                vectors.forEach(vector -> data.add((ByteBuffer) vector.getData()));
+                builder.withFloat16Vectors(data);
+                break;
+            }
+            case BFloat16Vector: {
+                List<ByteBuffer> data = new ArrayList<>();
+                vectors.forEach(vector -> data.add((ByteBuffer) vector.getData()));
+                builder.withBFloat16Vectors(data);
+                break;
+            }
+            case SparseFloatVector: {
+                List<SortedMap<Long, Float>> data = new ArrayList<>();
+                vectors.forEach(vector -> data.add((SortedMap<Long, Float>) vector.getData()));
+                builder.withSparseFloatVectors(data);
+                break;
+            }
+            default:
+                throw new ParamException("Unsupported vector type.");
+        }
+
+        return builder.build();
+    }
+
+    /**
+     * Converts a V2 field schema into a V1 {@link FieldType}.
+     * Dimension, max length, max capacity and element type are only copied when set.
+     *
+     * @param schema the V2 field schema
+     * @return the equivalent V1 field type
+     */
+    public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) {
+        FieldType.Builder builder = FieldType.newBuilder()
+                .withName(schema.getName())
+                .withDataType(DataType.valueOf(schema.getDataType().name()))
+                .withPrimaryKey(schema.getIsPrimaryKey())
+                .withAutoID(schema.getAutoID())
+                .withPartitionKey(schema.getIsPartitionKey());
+
+        if (schema.getDimension() != null) {
+            builder.withDimension(schema.getDimension());
+        }
+        if (schema.getMaxLength() != null) {
+            builder.withMaxLength(schema.getMaxLength());
+        }
+        if (schema.getMaxCapacity() != null) {
+            // the capacity must come from getMaxCapacity(), not getMaxLength()
+            builder.withMaxCapacity(schema.getMaxCapacity());
+        }
+        if (schema.getElementType() != null) {
+            builder.withElementType(DataType.valueOf(schema.getElementType().name()));
+        }
+        return builder.build();
+    }
+}

+ 21 - 0
src/main/java/io/milvus/orm/iterator/QueryIterator.java

@@ -28,6 +28,8 @@ import io.milvus.param.collection.FieldType;
 import io.milvus.param.dml.QueryIteratorParam;
 import io.milvus.param.dml.QueryParam;
 import io.milvus.response.QueryResultsWrapper;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.vector.request.QueryIteratorReq;
 import io.milvus.v2.utils.RpcUtils;
 import org.apache.commons.lang3.StringUtils;
 
@@ -68,6 +70,25 @@ public class QueryIterator {
         seek();
     }
 
+    /**
+     * Creates a query iterator from a V2 request by converting it to the V1 parameter
+     * objects, then positions the iterator via {@code seek()}.
+     *
+     * @param queryIteratorReq the V2 query-iterator request
+     * @param blockingStub     the gRPC stub used to talk to the server
+     * @param primaryField     schema of the collection's primary key field
+     */
+    public QueryIterator(QueryIteratorReq queryIteratorReq,
+                         MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                         CreateCollectionReq.FieldSchema primaryField) {
+        this.iteratorCache = new IteratorCache();
+        this.blockingStub = blockingStub;
+        // the adapter methods are static: call them on the class, not through a throwaway instance
+        this.queryIteratorParam = IteratorAdapterV2.convertV2Req(queryIteratorReq);
+        this.primaryField = IteratorAdapterV2.convertV2Field(primaryField);
+
+        this.batchSize = (int) queryIteratorParam.getBatchSize();
+        this.expr = queryIteratorParam.getExpr();
+        this.limit = queryIteratorParam.getLimit();
+        this.offset = queryIteratorParam.getOffset();
+        this.rpcUtils = new RpcUtils();
+
+        seek();
+    }
+
     private void seek() {
         this.cacheIdInUse = NO_CACHE_ID;
         if (offset == 0) {

+ 27 - 9
src/main/java/io/milvus/orm/iterator/SearchIterator.java

@@ -7,10 +7,7 @@ import com.google.common.collect.Lists;
 import io.milvus.common.utils.ExceptionUtils;
 import io.milvus.common.utils.JacksonUtils;
 import io.milvus.exception.ParamException;
-import io.milvus.grpc.DataType;
-import io.milvus.grpc.MilvusServiceGrpc;
-import io.milvus.grpc.SearchRequest;
-import io.milvus.grpc.SearchResults;
+import io.milvus.grpc.*;
 import io.milvus.param.MetricType;
 import io.milvus.param.ParamUtils;
 import io.milvus.param.collection.FieldType;
@@ -18,16 +15,15 @@ import io.milvus.param.dml.SearchIteratorParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.response.QueryResultsWrapper;
 import io.milvus.response.SearchResultsWrapper;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.vector.request.SearchIteratorReq;
 import io.milvus.v2.utils.RpcUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.nio.ByteBuffer;
 import java.text.DecimalFormat;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
+import java.util.*;
 
 import static io.milvus.param.Constant.DEFAULT_SEARCH_EXTENSION_RATE;
 import static io.milvus.param.Constant.EF;
@@ -82,6 +78,28 @@ public class SearchIterator {
         initSearchIterator();
     }
 
+    /**
+     * Creates a search iterator from a V2 request by converting it to the V1 parameter
+     * objects, then runs the same initialization steps as the V1 constructor.
+     *
+     * @param searchIteratorReq the V2 search-iterator request
+     * @param blockingStub      the gRPC stub used to talk to the server
+     * @param primaryField      schema of the collection's primary key field
+     */
+    public SearchIterator(SearchIteratorReq searchIteratorReq,
+                          MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                          CreateCollectionReq.FieldSchema primaryField) {
+        this.iteratorCache = new IteratorCache();
+        this.blockingStub = blockingStub;
+        // the adapter methods are static: call them on the class, not through a throwaway instance
+        this.searchIteratorParam = IteratorAdapterV2.convertV2Req(searchIteratorReq);
+        this.primaryField = IteratorAdapterV2.convertV2Field(primaryField);
+        this.metricType = this.searchIteratorParam.getMetricType();
+
+        this.batchSize = (int) this.searchIteratorParam.getBatchSize();
+        this.expr = this.searchIteratorParam.getExpr();
+        this.topK = this.searchIteratorParam.getTopK();
+        this.rpcUtils = new RpcUtils();
+
+        initParams();
+        checkForSpecialIndexParam();
+        checkRmRangeSearchParameters();
+        initSearchIterator();
+    }
+
     public List<QueryResultsWrapper.RowRecord> next() {
         // 0. check reached limit
         if (!initSuccess || checkReachedLimit()) {
@@ -178,7 +196,7 @@ public class SearchIterator {
         tailBand = getDistance(lastHit);
         String msg = String.format("set up init parameter for searchIterator width:%s tail_band:%s", width, tailBand);
         logger.debug(msg);
-        System.out.println(msg);
+//        System.out.println(msg);
     }
 
     private void updateFilteredIds(SearchResultsWrapper searchResultsWrapper) {

+ 10 - 0
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -45,6 +45,7 @@ import static io.milvus.grpc.DataType.JSON;
  */
 public class FieldDataWrapper {
     private final FieldData fieldData;
+    private List<?> cacheData = null;
 
     public FieldDataWrapper(@NonNull FieldData fieldData) {
         this.fieldData = fieldData;
@@ -182,6 +183,15 @@ public class FieldDataWrapper {
      * @return <code>List</code>
      */
     public List<?> getFieldData() throws IllegalResponseException {
+        // Decode the underlying FieldData on first use only; later calls reuse the cached list.
+        if (cacheData == null) {
+            cacheData = getFieldDataInternal();
+        }
+        return cacheData;
+    }
+
+    private List<?> getFieldDataInternal() throws IllegalResponseException {
         DataType dt = fieldData.getType();
         switch (dt) {
             case FloatVector: {

+ 3 - 3
src/main/java/io/milvus/response/QueryResultsWrapper.java

@@ -51,7 +51,7 @@ public class QueryResultsWrapper extends RowRecordWrapper {
         List<FieldData> fields = results.getFieldsDataList();
         for (FieldData field : fields) {
             if (fieldName.compareTo(field.getFieldName()) == 0) {
-                return new FieldDataWrapper(field);
+                return getFieldWrapperInternal(field);
             }
         }
 
@@ -95,7 +95,7 @@ public class QueryResultsWrapper extends RowRecordWrapper {
     public long getRowCount() {
         List<FieldData> fields = results.getFieldsDataList();
         for (FieldData field : fields) {
-            FieldDataWrapper wrapper = new FieldDataWrapper(field);
+            FieldDataWrapper wrapper = getFieldWrapperInternal(field);
             return wrapper.getRowCount();
         }
 
@@ -135,7 +135,7 @@ public class QueryResultsWrapper extends RowRecordWrapper {
          * If the key name is in dynamic field, return the value from the dynamic field.
          * Throws {@link ParamException} if the key name doesn't exist.
          *
-         * @return {@link FieldDataWrapper}
+         * @return {@link Object}
          */
         public Object get(String keyName) throws ParamException {
             if (fieldValues.isEmpty()) {

+ 3 - 3
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -54,7 +54,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
         List<FieldData> fields = results.getFieldsDataList();
         for (FieldData field : fields) {
             if (fieldName.compareTo(field.getFieldName()) == 0) {
-                return new FieldDataWrapper(field);
+                return getFieldWrapperInternal(field);
             }
         }
 
@@ -195,10 +195,10 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             FieldDataWrapper dynamicField = null;
             for (FieldData field : fields) {
                 if (field.getIsDynamic()) {
-                    dynamicField = new FieldDataWrapper(field);
+                    dynamicField = getFieldWrapperInternal(field);
                 }
                 if (outputKey.equals(field.getFieldName())) {
-                    FieldDataWrapper wrapper = new FieldDataWrapper(field);
+                    FieldDataWrapper wrapper = getFieldWrapperInternal(field);
                     for (int n = 0; n < k; ++n) {
                         if ((offset + n) >= wrapper.getRowCount()) {
                             throw new ParamException("Illegal values length of output fields");

+ 15 - 2
src/main/java/io/milvus/response/basic/RowRecordWrapper.java

@@ -27,11 +27,24 @@ import io.milvus.response.QueryResultsWrapper;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
 
 public abstract class RowRecordWrapper {
+    // a cache for output fields
+    private ConcurrentHashMap<String, FieldDataWrapper> outputFieldsData = new ConcurrentHashMap<>();
 
     public abstract List<QueryResultsWrapper.RowRecord> getRowRecords();
 
+    /**
+     * Returns the cached wrapper for the given field, creating and caching it on first access.
+     * Wrappers are keyed by field name.
+     *
+     * @param field the raw field data returned by the server
+     * @return the wrapper associated with the field's name
+     */
+    protected FieldDataWrapper getFieldWrapperInternal(FieldData field) {
+        // computeIfAbsent performs the lookup and the insertion as one atomic step, unlike a
+        // separate containsKey/get/put sequence that can create duplicate wrappers under contention
+        return outputFieldsData.computeIfAbsent(field.getFieldName(), name -> new FieldDataWrapper(field));
+    }
+    }
+
     /**
      * Get the dynamic field. Only available when a collection's dynamic field is enabled.
      * Throws {@link ParamException} if the dynamic field doesn't exist.
@@ -42,7 +55,7 @@ public abstract class RowRecordWrapper {
         List<FieldData> fields = getFieldDataList();
         for (FieldData field : fields) {
             if (field.getIsDynamic()) {
-                return new FieldDataWrapper(field);
+                return getFieldWrapperInternal(field);
             }
         }
 
@@ -60,7 +73,7 @@ public abstract class RowRecordWrapper {
             boolean isField = false;
             for (FieldData field : getFieldDataList()) {
                 if (outputKey.equals(field.getFieldName())) {
-                    FieldDataWrapper wrapper = new FieldDataWrapper(field);
+                    FieldDataWrapper wrapper = getFieldWrapperInternal(field);
                     if (index < 0 || index >= wrapper.getRowCount()) {
                         throw new ParamException("Index out of range");
                     }

+ 23 - 0
src/main/java/io/milvus/v2/client/MilvusClientV2.java

@@ -21,6 +21,8 @@ package io.milvus.v2.client;
 
 import io.grpc.ManagedChannel;
 import io.milvus.grpc.MilvusServiceGrpc;
+import io.milvus.orm.iterator.QueryIterator;
+import io.milvus.orm.iterator.SearchIterator;
 import io.milvus.v2.service.collection.CollectionService;
 import io.milvus.v2.service.collection.request.*;
 import io.milvus.v2.service.collection.response.DescribeCollectionResp;
@@ -320,6 +322,27 @@ public class MilvusClientV2 {
         return vectorService.hybridSearch(this.blockingStub, request);
     }
 
+    /**
+     * Creates a query iterator over entities filtered by a boolean expression on scalar field(s).
+     * Note that the order of the returned entities cannot be guaranteed.
+     * The call is delegated to the vector service using this client's blocking stub.
+     *
+     * @param request {@link QueryIteratorReq}
+     * @return the {@link QueryIterator} created for the request
+     */
+    public QueryIterator queryIterator(QueryIteratorReq request) {
+        return vectorService.queryIterator(this.blockingStub, request);
+    }
+
+    /**
+     * Creates a search iterator based on a vector field. Use expression to do filtering before search.
+     * The call is delegated to the vector service using this client's blocking stub.
+     *
+     * @param request {@link SearchIteratorReq}
+     * @return the {@link SearchIterator} created for the request
+     */
+    public SearchIterator searchIterator(SearchIteratorReq request) {
+        return vectorService.searchIterator(this.blockingStub, request);
+    }
+
     // Partition Operations
     /**
      * Creates a partition in a collection in Milvus.

+ 18 - 0
src/main/java/io/milvus/v2/service/vector/VectorService.java

@@ -21,11 +21,13 @@ package io.milvus.v2.service.vector;
 
 import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
+import io.milvus.orm.iterator.*;
 import io.milvus.response.DescCollResponseWrapper;
 import io.milvus.v2.exception.ErrorCode;
 import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.BaseService;
 import io.milvus.v2.service.collection.CollectionService;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.collection.request.DescribeCollectionReq;
 import io.milvus.v2.service.collection.response.DescribeCollectionResp;
 import io.milvus.v2.service.index.IndexService;
@@ -176,6 +178,22 @@ public class VectorService extends BaseService {
                 .build();
     }
 
+    /**
+     * Builds a {@link QueryIterator} for the request. The collection is described first
+     * so that the schema of its primary key field can be handed to the iterator.
+     */
+    public QueryIterator queryIterator(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                                           QueryIteratorReq request) {
+        DescribeCollectionResp collectionDesc = CollectionService.convertDescCollectionResp(
+                getCollectionInfo(blockingStub, "", request.getCollectionName()));
+        String primaryFieldName = collectionDesc.getPrimaryFieldName();
+        CreateCollectionReq.FieldSchema primaryField =
+                collectionDesc.getCollectionSchema().getField(primaryFieldName);
+        return new QueryIterator(request, blockingStub, primaryField);
+    }
+
+    /**
+     * Builds a {@link SearchIterator} for the request. The collection is described first
+     * so that the schema of its primary key field can be handed to the iterator.
+     */
+    public SearchIterator searchIterator(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub,
+                                            SearchIteratorReq request) {
+        DescribeCollectionResp collectionDesc = CollectionService.convertDescCollectionResp(
+                getCollectionInfo(blockingStub, "", request.getCollectionName()));
+        String primaryFieldName = collectionDesc.getPrimaryFieldName();
+        CreateCollectionReq.FieldSchema primaryField =
+                collectionDesc.getCollectionSchema().getField(primaryFieldName);
+        return new SearchIterator(request, blockingStub, primaryField);
+    }
+
     public DeleteResp delete(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DeleteReq request) {
         String title = String.format("DeleteRequest collectionName:%s", request.getCollectionName());
 

+ 32 - 0
src/main/java/io/milvus/v2/service/vector/request/QueryIteratorReq.java

@@ -0,0 +1,32 @@
+package io.milvus.v2.service.vector.request;
+
+import com.google.common.collect.Lists;
+import io.milvus.v2.common.ConsistencyLevel;
+import lombok.Builder;
+import lombok.Data;
+import lombok.experimental.SuperBuilder;
+
+import java.util.List;
+
+/**
+ * Request parameters used to create a query iterator through the V2 client.
+ * <p>
+ * Lombok generates the accessors ({@code @Data}) and the builder ({@code @SuperBuilder}).
+ * Fields annotated with {@code @Builder.Default} keep the initializer written here
+ * when the builder does not set them; the other fields are {@code null} unless set.
+ */
+@Data
+@SuperBuilder
+public class QueryIteratorReq {
+    // No builder default: null unless set explicitly.
+    private String databaseName;
+    // Name of the collection to iterate over. No builder default: null unless set explicitly.
+    private String collectionName;
+    // Defaults to an empty list. Presumably restricts the iteration to these partitions — TODO confirm.
+    @Builder.Default
+    private List<String> partitionNames = Lists.newArrayList();
+    // Names of the fields to return for each row; defaults to an empty list.
+    @Builder.Default
+    private List<String> outputFields = Lists.newArrayList();
+    // Filter expression; defaults to the empty string.
+    @Builder.Default
+    private String expr = "";
+    // Defaults to BOUNDED.
+    @Builder.Default
+    private ConsistencyLevel consistencyLevel = ConsistencyLevel.BOUNDED;
+    // Defaults to 0. Presumably the number of leading rows to skip — TODO confirm against QueryIterator.
+    @Builder.Default
+    private long offset = 0;
+    // Defaults to -1. Presumably -1 means "no limit" — TODO confirm against QueryIterator.
+    @Builder.Default
+    private long limit = -1;
+    // Defaults to false.
+    @Builder.Default
+    private boolean ignoreGrowing = false;
+    // Defaults to 1000. Presumably the maximum number of rows fetched per next() call — TODO confirm.
+    @Builder.Default
+    private long batchSize = 1000L;
+}

+ 43 - 0
src/main/java/io/milvus/v2/service/vector/request/SearchIteratorReq.java

@@ -0,0 +1,43 @@
+package io.milvus.v2.service.vector.request;
+
+import com.google.common.collect.Lists;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.vector.request.data.BaseVector;
+import lombok.Builder;
+import lombok.Data;
+import lombok.experimental.SuperBuilder;
+
+import java.util.List;
+
+/**
+ * Request parameters used to create a search iterator through the V2 client.
+ * <p>
+ * Lombok generates the accessors ({@code @Data}) and the builder ({@code @SuperBuilder}).
+ * Fields annotated with {@code @Builder.Default} keep the initializer written here
+ * when the builder does not set them; the other fields are {@code null} unless set.
+ */
+@Data
+@SuperBuilder
+public class SearchIteratorReq {
+    // No builder default: null unless set explicitly.
+    private String databaseName;
+    // Name of the collection to search. No builder default: null unless set explicitly.
+    private String collectionName;
+    // Defaults to an empty list. Presumably restricts the search to these partitions — TODO confirm.
+    @Builder.Default
+    private List<String> partitionNames = Lists.newArrayList();
+    // Defaults to INVALID, i.e. no metric type chosen by the caller.
+    @Builder.Default
+    private IndexParam.MetricType metricType = IndexParam.MetricType.INVALID;
+    // Name of the vector field to search on. No builder default: null unless set explicitly.
+    private String vectorFieldName;
+    // Defaults to -1. NOTE(review): how -1 is interpreted is decided by SearchIterator — confirm there.
+    @Builder.Default
+    private int topK = -1;
+    // Filter expression; defaults to the empty string.
+    @Builder.Default
+    private String expr = "";
+    // Names of the fields to return for each hit; defaults to an empty list.
+    @Builder.Default
+    private List<String> outputFields = Lists.newArrayList();
+    // Target vectors to search with; defaults to an empty list.
+    @Builder.Default
+    private List<BaseVector> vectors = Lists.newArrayList();
+    // Defaults to -1. Presumably -1 disables rounding of the returned distances — TODO confirm.
+    @Builder.Default
+    private int roundDecimal = -1;
+    // Extra search parameters as a JSON string; defaults to an empty JSON object.
+    @Builder.Default
+    private String params = "{}";
+    // Defaults to BOUNDED.
+    @Builder.Default
+    private ConsistencyLevel consistencyLevel = ConsistencyLevel.BOUNDED;
+    // Defaults to false.
+    @Builder.Default
+    private boolean ignoreGrowing = false;
+    // Defaults to the empty string, i.e. no group-by field named by the caller.
+    @Builder.Default
+    private String groupByFieldName = "";
+    // Defaults to 1000. Presumably the maximum number of hits fetched per next() call — TODO confirm.
+    @Builder.Default
+    private long batchSize = 1000L;
+}

+ 6 - 3
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -864,7 +864,7 @@ class MilvusClientDockerTest {
                 .withCollectionName(randomCollectionName)
                 .withMetricType(MetricType.IP)
                 .withTopK(topK)
-                .withVectors(targetVectors)
+                .withSparseFloatVectors(targetVectors)
                 .withVectorFieldName(field2Name)
                 .addOutField(field2Name)
                 .withParams("{\"drop_ratio_search\":0.2}")
@@ -880,11 +880,14 @@ class MilvusClientDockerTest {
             List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
             System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
             System.out.println(scores);
-            Assertions.assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
+            if (targetVectorIDs.get(i) != scores.get(0).getLongID()) {
+                System.out.println(targetVectors.get(i));
+            }
+            Assertions.assertEquals(targetVectorIDs.get(i), scores.get(0).getLongID());
 
             Object v = scores.get(0).get(field2Name);
             SortedMap<Long, Float> sparse = (SortedMap<Long, Float>)v;
-            Assertions.assertTrue(sparse.equals(targetVectors.get(i)));
+            Assertions.assertEquals(sparse, targetVectors.get(i));
             Assertions.assertEquals(targetVectors.get(i).size(), sparse.size());
             for (Long key : sparse.keySet()) {
                 Assertions.assertTrue(targetVectors.get(i).containsKey(key));

+ 213 - 1
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -19,10 +19,15 @@
 
 package io.milvus.v2.client;
 
+import com.google.common.collect.Lists;
 import com.google.gson.*;
 
 import com.google.gson.reflect.TypeToken;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.Float16Utils;
+import io.milvus.orm.iterator.QueryIterator;
+import io.milvus.orm.iterator.SearchIterator;
+import io.milvus.response.QueryResultsWrapper;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
@@ -41,7 +46,6 @@ import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.request.data.*;
 import io.milvus.v2.service.vector.request.ranker.*;
 import io.milvus.v2.service.vector.response.*;
-import io.netty.buffer.ByteBuf;
 import org.apache.commons.text.RandomStringGenerator;
 
 import org.junit.jupiter.api.Assertions;
@@ -1264,4 +1268,212 @@ class MilvusClientV2DockerTest {
                 .build());
         Assertions.assertEquals(1L, insertResp.getInsertCnt());
     }
+
+    /**
+     * End-to-end check of the V2 search iterator and query iterator.
+     * <p>
+     * Creates a collection with four vector fields on top of baseSchema(), inserts 10000 random rows,
+     * then walks a search iterator (range search on "float_vector") and a query iterator
+     * (filter + offset + limit), validating every returned row. The collection is dropped at the end.
+     */
+    @Test
+    public void testIterator() {
+        String randomCollectionName = generator.generate(10);
+        CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("float_vector")
+                .dataType(DataType.FloatVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("binary_vector")
+                .dataType(DataType.BinaryVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("sparse_vector")
+                .dataType(DataType.SparseFloatVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("bfloat16_vector")
+                .dataType(DataType.BFloat16Vector)
+                .dimension(dimension)
+                .build());
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName("float_vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.L2)
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("binary_vector")
+                .indexType(IndexParam.IndexType.BIN_FLAT)
+                .metricType(IndexParam.MetricType.HAMMING)
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("sparse_vector")
+                .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
+                .metricType(IndexParam.MetricType.IP)
+                .extraParams(new HashMap<String,Object>(){{put("drop_ratio_build", 0.1);}})
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("bfloat16_vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexParams)
+                .build();
+        client.createCollection(requestCreate);
+
+        // insert rows
+        long count = 10000;
+        List<JsonObject> data = generateRandomData(collectionSchema, count);
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(data)
+                .build());
+        Assertions.assertEquals(count, insertResp.getInsertCnt());
+
+        // get row count
+        long rowCount = getRowCount(randomCollectionName);
+        Assertions.assertEquals(count, rowCount);
+
+        // search iterator
+        SearchIterator searchIterator = client.searchIterator(SearchIteratorReq.builder()
+                .collectionName(randomCollectionName)
+                .outputFields(Lists.newArrayList("*"))
+                .batchSize(20L)
+                .vectorFieldName("float_vector")
+                .vectors(Collections.singletonList(new FloatVec(generateFolatVector())))
+                .expr("int64_field > 500 && int64_field < 1000")
+                .params("{\"range_filter\": 5.0, \"radius\": 50.0}")
+                .topK(1000)
+                .metricType(IndexParam.MetricType.L2)
+                .consistencyLevel(ConsistencyLevel.EVENTUALLY)
+                .build());
+
+        int counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = searchIterator.next();
+            // an empty batch marks the end of the iteration
+            if (res.isEmpty()) {
+                System.out.println("search iteration finished, close");
+                searchIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                // the distance must fall inside the [range_filter, radius] window passed in params
+                Assertions.assertInstanceOf(Float.class, record.get("distance"));
+                Assertions.assertTrue((float)record.get("distance") >= 5.0);
+                Assertions.assertTrue((float)record.get("distance") <= 50.0);
+
+                // the row must satisfy the filter expression given to the search iterator
+                long int64Val = verifyIteratorRowRecord(record);
+                Assertions.assertTrue(int64Val > 500L && int64Val < 1000L);
+
+                counter++;
+            }
+        }
+        System.out.println(String.format("There are %d items match distance between [5.0, 50.0]", counter));
+
+        // query iterator
+        QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
+                .collectionName(randomCollectionName)
+                .expr("int64_field < 300")
+                .outputFields(Lists.newArrayList("*"))
+                .batchSize(50L)
+                .offset(5)
+                .limit(400)
+                .consistencyLevel(ConsistencyLevel.EVENTUALLY)
+                .build());
+
+        counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = queryIterator.next();
+            // an empty batch marks the end of the iteration
+            if (res.isEmpty()) {
+                System.out.println("query iteration finished, close");
+                queryIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                // the row must satisfy the filter expression given to the query iterator
+                long int64Val = verifyIteratorRowRecord(record);
+                Assertions.assertTrue(int64Val < 300L);
+
+                counter++;
+            }
+        }
+        // NOTE(review): 295 presumably comes from 300 rows matching "int64_field < 300" minus the
+        // offset of 5 (the limit of 400 is never reached) — confirm against generateRandomData().
+        Assertions.assertEquals(295, counter);
+
+        client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
+    }
+
+    /**
+     * Asserts the Java type and basic shape of every field of a row returned by the iterators
+     * in testIterator() when all fields are requested with the "*" output field.
+     * Shared by the search-iterator and query-iterator loops so both apply identical checks.
+     *
+     * @param record a row fetched from an iterator over the testIterator() collection
+     * @return the value of the row's "int64_field", so the caller can check it against its own filter
+     */
+    private long verifyIteratorRowRecord(QueryResultsWrapper.RowRecord record) {
+        Assertions.assertInstanceOf(Boolean.class, record.get("bool_field"));
+        Assertions.assertInstanceOf(Integer.class, record.get("int8_field"));
+        Assertions.assertInstanceOf(Integer.class, record.get("int16_field"));
+        Assertions.assertInstanceOf(Integer.class, record.get("int32_field"));
+        Assertions.assertInstanceOf(Long.class, record.get("int64_field"));
+        Assertions.assertInstanceOf(Float.class, record.get("float_field"));
+        Assertions.assertInstanceOf(Double.class, record.get("double_field"));
+        Assertions.assertInstanceOf(String.class, record.get("varchar_field"));
+        Assertions.assertInstanceOf(JsonObject.class, record.get("json_field"));
+        Assertions.assertInstanceOf(List.class, record.get("arr_int_field"));
+        Assertions.assertInstanceOf(List.class, record.get("float_vector"));
+        Assertions.assertInstanceOf(ByteBuffer.class, record.get("binary_vector"));
+        Assertions.assertInstanceOf(ByteBuffer.class, record.get("bfloat16_vector"));
+        Assertions.assertInstanceOf(SortedMap.class, record.get("sparse_vector"));
+
+        long int64Val = (long)record.get("int64_field");
+
+        String varcharVal = (String)record.get("varchar_field");
+        Assertions.assertTrue(varcharVal.startsWith("varchar_"));
+
+        // the JSON field is expected to carry a key derived from the row's int64 value
+        JsonObject jsonObj = (JsonObject)record.get("json_field");
+        Assertions.assertTrue(jsonObj.has(String.format("JSON_%d", int64Val)));
+
+        List<Integer> intArr = (List<Integer>)record.get("arr_int_field");
+        Assertions.assertTrue(intArr.size() <= 50); // max capacity 50 is defined in the baseSchema()
+
+        List<Float> floatVector = (List<Float>)record.get("float_vector");
+        Assertions.assertEquals(dimension, floatVector.size());
+
+        // binary vectors pack 8 dimensions per byte
+        ByteBuffer binaryVector = (ByteBuffer)record.get("binary_vector");
+        Assertions.assertEquals(dimension, binaryVector.limit()*8);
+
+        // bfloat16 vectors take 2 bytes per dimension
+        ByteBuffer bfloat16Vector = (ByteBuffer)record.get("bfloat16_vector");
+        Assertions.assertEquals(dimension*2, bfloat16Vector.limit());
+
+        SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)record.get("sparse_vector");
+        Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+
+        return int64Val;
+    }
 }