@@ -19,10 +19,15 @@
 
 package io.milvus.v2.client;
 
+import com.google.common.collect.Lists;
 import com.google.gson.*;
 import com.google.gson.reflect.TypeToken;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.common.utils.Float16Utils;
+import io.milvus.orm.iterator.QueryIterator;
+import io.milvus.orm.iterator.SearchIterator;
+import io.milvus.response.QueryResultsWrapper;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
@@ -41,7 +46,6 @@ import io.milvus.v2.service.vector.request.*;
 import io.milvus.v2.service.vector.request.data.*;
 import io.milvus.v2.service.vector.request.ranker.*;
 import io.milvus.v2.service.vector.response.*;
-import io.netty.buffer.ByteBuf;
 import org.apache.commons.text.RandomStringGenerator;
 
 import org.junit.jupiter.api.Assertions;
@@ -1264,4 +1268,212 @@ class MilvusClientV2DockerTest {
                 .build());
         Assertions.assertEquals(1L, insertResp.getInsertCnt());
     }
+
+    @Test
+    public void testIterator() {
+        String randomCollectionName = generator.generate(10);
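+        // baseSchema() is an existing helper in this test class; it is assumed to supply the
+        // primary key and the scalar fields (bool/int/float/double/varchar/json/array) asserted below.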
+        CreateCollectionReq.CollectionSchema collectionSchema = baseSchema();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("float_vector")
+                .dataType(DataType.FloatVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("binary_vector")
+                .dataType(DataType.BinaryVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("sparse_vector")
+                .dataType(DataType.SparseFloatVector)
+                .dimension(dimension)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("bfloat16_vector")
+                .dataType(DataType.BFloat16Vector)
+                .dimension(dimension)
+                .build());
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName("float_vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.L2)
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("binary_vector")
+                .indexType(IndexParam.IndexType.BIN_FLAT)
+                .metricType(IndexParam.MetricType.HAMMING)
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("sparse_vector")
+                .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
+                .metricType(IndexParam.MetricType.IP)
+                .extraParams(new HashMap<String,Object>(){{put("drop_ratio_build", 0.1);}})
+                .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName("bfloat16_vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexParams)
+                .build();
+        client.createCollection(requestCreate);
+
+        // insert rows
+        long count = 10000;
+        List<JsonObject> data = generateRandomData(collectionSchema, count);
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(data)
+                .build());
+        Assertions.assertEquals(count, insertResp.getInsertCnt());
+
+        // get row count
+        long rowCount = getRowCount(randomCollectionName);
+        Assertions.assertEquals(count, rowCount);
+
+        // search iterator
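+        // With the L2 metric, range_filter=5.0 and radius=50.0 keep only results whose distance
+        // falls in [5.0, 50.0] (checked by the assertions below); each next() call returns at most batchSize rows.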
+        SearchIterator searchIterator = client.searchIterator(SearchIteratorReq.builder()
+                .collectionName(randomCollectionName)
+                .outputFields(Lists.newArrayList("*"))
+                .batchSize(20L)
+                .vectorFieldName("float_vector")
+                .vectors(Collections.singletonList(new FloatVec(generateFolatVector())))
+                .expr("int64_field > 500 && int64_field < 1000")
+                .params("{\"range_filter\": 5.0, \"radius\": 50.0}")
+                .topK(1000)
+                .metricType(IndexParam.MetricType.L2)
+                .consistencyLevel(ConsistencyLevel.EVENTUALLY)
+                .build());
+
+        int counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = searchIterator.next();
+            if (res.isEmpty()) {
+                System.out.println("search iteration finished, close");
+                searchIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                Assertions.assertInstanceOf(Float.class, record.get("distance"));
+                Assertions.assertTrue((float)record.get("distance") >= 5.0);
+                Assertions.assertTrue((float)record.get("distance") <= 50.0);
+
+                Assertions.assertInstanceOf(Boolean.class, record.get("bool_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int8_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int16_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int32_field"));
+                Assertions.assertInstanceOf(Long.class, record.get("int64_field"));
+                Assertions.assertInstanceOf(Float.class, record.get("float_field"));
+                Assertions.assertInstanceOf(Double.class, record.get("double_field"));
+                Assertions.assertInstanceOf(String.class, record.get("varchar_field"));
+                Assertions.assertInstanceOf(JsonObject.class, record.get("json_field"));
+                Assertions.assertInstanceOf(List.class, record.get("arr_int_field"));
+                Assertions.assertInstanceOf(List.class, record.get("float_vector"));
+                Assertions.assertInstanceOf(ByteBuffer.class, record.get("binary_vector"));
+                Assertions.assertInstanceOf(ByteBuffer.class, record.get("bfloat16_vector"));
+                Assertions.assertInstanceOf(SortedMap.class, record.get("sparse_vector"));
+
+                long int64Val = (long)record.get("int64_field");
+                Assertions.assertTrue(int64Val > 500L && int64Val < 1000L);
+
+                String varcharVal = (String)record.get("varchar_field");
+                Assertions.assertTrue(varcharVal.startsWith("varchar_"));
+
+                JsonObject jsonObj = (JsonObject)record.get("json_field");
+                Assertions.assertTrue(jsonObj.has(String.format("JSON_%d", int64Val)));
+
+                List<Integer> intArr = (List<Integer>)record.get("arr_int_field");
+                Assertions.assertTrue(intArr.size() <= 50); // max capacity 50 is defined in the baseSchema()
+
+                List<Float> floatVector = (List<Float>)record.get("float_vector");
+                Assertions.assertEquals(dimension, floatVector.size());
+
+                ByteBuffer binaryVector = (ByteBuffer)record.get("binary_vector");
+                Assertions.assertEquals(dimension, binaryVector.limit()*8);
+
+                ByteBuffer bfloat16Vector = (ByteBuffer)record.get("bfloat16_vector");
+                Assertions.assertEquals(dimension*2, bfloat16Vector.limit());
+
+                SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)record.get("sparse_vector");
+                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+
+                counter++;
+            }
+        }
+        System.out.println(String.format("There are %d items match distance between [5.0, 50.0]", counter));
+
+        // query iterator
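+        // QueryIterator pages through rows matching the filter expression; offset skips the first
+        // matching rows, limit caps the total returned, and batchSize bounds each next() call.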
+        QueryIterator queryIterator = client.queryIterator(QueryIteratorReq.builder()
+                .collectionName(randomCollectionName)
+                .expr("int64_field < 300")
+                .outputFields(Lists.newArrayList("*"))
+                .batchSize(50L)
+                .offset(5)
+                .limit(400)
+                .consistencyLevel(ConsistencyLevel.EVENTUALLY)
+                .build());
+
+        counter = 0;
+        while (true) {
+            List<QueryResultsWrapper.RowRecord> res = queryIterator.next();
+            if (res.isEmpty()) {
+                System.out.println("query iteration finished, close");
+                queryIterator.close();
+                break;
+            }
+
+            for (QueryResultsWrapper.RowRecord record : res) {
+                Assertions.assertInstanceOf(Boolean.class, record.get("bool_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int8_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int16_field"));
+                Assertions.assertInstanceOf(Integer.class, record.get("int32_field"));
+                Assertions.assertInstanceOf(Long.class, record.get("int64_field"));
+                Assertions.assertInstanceOf(Float.class, record.get("float_field"));
+                Assertions.assertInstanceOf(Double.class, record.get("double_field"));
+                Assertions.assertInstanceOf(String.class, record.get("varchar_field"));
+                Assertions.assertInstanceOf(JsonObject.class, record.get("json_field"));
+                Assertions.assertInstanceOf(List.class, record.get("arr_int_field"));
+                Assertions.assertInstanceOf(List.class, record.get("float_vector"));
+                Assertions.assertInstanceOf(ByteBuffer.class, record.get("binary_vector"));
+                Assertions.assertInstanceOf(ByteBuffer.class, record.get("bfloat16_vector"));
+                Assertions.assertInstanceOf(SortedMap.class, record.get("sparse_vector"));
+
+                long int64Val = (long)record.get("int64_field");
+                Assertions.assertTrue(int64Val < 300L);
+
+                String varcharVal = (String)record.get("varchar_field");
+                Assertions.assertTrue(varcharVal.startsWith("varchar_"));
+
+                JsonObject jsonObj = (JsonObject)record.get("json_field");
+                Assertions.assertTrue(jsonObj.has(String.format("JSON_%d", int64Val)));
+
+                List<Integer> intArr = (List<Integer>)record.get("arr_int_field");
+                Assertions.assertTrue(intArr.size() <= 50); // max capacity 50 is defined in the baseSchema()
+
+                List<Float> floatVector = (List<Float>)record.get("float_vector");
+                Assertions.assertEquals(dimension, floatVector.size());
+
+                ByteBuffer binaryVector = (ByteBuffer)record.get("binary_vector");
+                Assertions.assertEquals(dimension, binaryVector.limit()*8);
+
+                ByteBuffer bfloat16Vector = (ByteBuffer)record.get("bfloat16_vector");
+                Assertions.assertEquals(dimension*2, bfloat16Vector.limit());
+
+                SortedMap<Long, Float> sparseVector = (SortedMap<Long, Float>)record.get("sparse_vector");
+                Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector()
+
+                counter++;
+            }
+        }
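+        // Assuming generateRandomData() assigns int64_field sequentially (0..count-1), the expr matches
+        // 300 rows; with offset(5) the iterator is expected to return 300 - 5 = 295 rows in total.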
+        Assertions.assertEquals(295, counter);
+
+        client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());
+    }
 }