Explorar o código

add testcases for random sampling (#1333)

Signed-off-by: yongpengli-z <yongpeng.li@zilliz.com>
yongpengli-z hai 1 mes
pai
achega
72ac7eda7d

+ 1 - 1
.github/workflows/java_sdk_ci_test.yaml

@@ -28,7 +28,7 @@ jobs:
         run: |
           echo "build jar"
           git submodule update --init
-          mvn clean versions:set -DnewVersion=2.4.0
+          mvn clean versions:set -DnewVersion=2.5.5
           mvn clean install -Dmaven.test.skip=true
 
 #      - name: Test

+ 1 - 1
tests/milvustestv2/pom.xml

@@ -70,7 +70,7 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>2.4.0</version>
+            <version>2.5.5</version>
         </dependency>
         <dependency>
             <groupId>com.google.protobuf</groupId>

+ 2 - 0
tests/milvustestv2/src/main/java/com/zilliz/milvustestv2/common/BaseTest.java

@@ -5,6 +5,8 @@ import com.zilliz.milvustestv2.Milvustestv2Application;
 import com.zilliz.milvustestv2.config.ConnectInfoConfig;
 import com.zilliz.milvustestv2.params.FieldParam;
 import com.zilliz.milvustestv2.utils.PropertyFilesUtil;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.param.ConnectParam;
 import io.milvus.param.MetricType;
 import io.milvus.v2.client.ConnectConfig;
 import io.milvus.v2.client.MilvusClientV2;

+ 13 - 13
tests/milvustestv2/src/main/java/com/zilliz/milvustestv2/common/CommonFunction.java

@@ -742,11 +742,11 @@ public class CommonFunction {
         for (long i = startId; i < (num + startId); i++) {
             JsonObject row = new JsonObject();
             row.addProperty(CommonData.fieldInt64, i);
-            row.addProperty(CommonData.fieldInt32, (String) null);
-            row.addProperty(CommonData.fieldInt16, (int) i % 32767);
-            row.addProperty(CommonData.fieldInt8, (short) i % 127);
-            row.addProperty(CommonData.fieldBool, i % 2 == 0);
-            if (i % 2 == 1) {
+            if (i % 2 == 0) {
+                row.addProperty(CommonData.fieldInt32, (int) i % 32767);
+                row.addProperty(CommonData.fieldInt16, (int) i % 32767);
+                row.addProperty(CommonData.fieldInt8, (short) i % 127);
+                row.addProperty(CommonData.fieldBool, i % 3 == 0);
                 row.addProperty(CommonData.fieldDouble, (double) i);
                 row.addProperty(CommonData.fieldVarchar, "Str" + i);
                 row.addProperty(CommonData.fieldFloat, (float) i);
@@ -774,15 +774,15 @@ public class CommonFunction {
             }
 
             JsonObject json = new JsonObject();
-            json.addProperty(CommonData.fieldInt64, (int) i % 32767);
-            json.addProperty(CommonData.fieldInt32, (int) i % 32767);
-            json.addProperty(CommonData.fieldDouble, (double) i);
-            json.add(CommonData.fieldArray, gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
-            json.addProperty(CommonData.fieldBool, i % 2 == 0);
-            if (i % 2 == 1) {
+            if (i % 2 == 0) {
+                json.addProperty(CommonData.fieldInt64, (int) i % 32767);
+                json.addProperty(CommonData.fieldInt32, (int) i % 32767);
+                json.addProperty(CommonData.fieldDouble, (double) i);
+                json.add(CommonData.fieldArray, gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
+                json.addProperty(CommonData.fieldBool, i % 3 == 0);
                 json.addProperty(CommonData.fieldVarchar, "Str" + i);
+                json.addProperty(CommonData.fieldFloat, (float) i);
             }
-            json.addProperty(CommonData.fieldFloat, (float) i);
             row.add(CommonData.fieldJson, json);
             jsonList.add(row);
         }
@@ -1542,7 +1542,7 @@ public class CommonFunction {
         CreateCollectionReq.CollectionSchema collectionSchema = describeCollectionResp.getCollectionSchema();
         RemoteBulkWriter remoteBulkWriter = buildRemoteBulkWriter(collectionSchema, bulkFileType);
         List<JsonObject> jsonObjects = CommonFunction.genCommonData(collection, count);
-        jsonObjects.forEach(x->{
+        jsonObjects.forEach(x -> {
             try {
                 remoteBulkWriter.appendRow(x);
             } catch (IOException | InterruptedException e) {

+ 2 - 1
tests/milvustestv2/src/test/java/com/zilliz/milvustestv2/alias/AlterAliasTest.java

@@ -39,8 +39,9 @@ public class AlterAliasTest extends BaseTest {
     @AfterClass(alwaysRun = true)
     public void cleanTestData(){
         milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(newCollectionName).build());
-        milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(newCollectionName2).build());
+        // remove alias before drop collection
         milvusClientV2.dropAlias(DropAliasReq.builder().alias(aliasName).build());
+        milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(newCollectionName2).build());
     }
 
     @Test(description = "Alter alias test",groups = {"Smoke"})

+ 2 - 0
tests/milvustestv2/src/test/java/com/zilliz/milvustestv2/alias/CreateAliasTest.java

@@ -10,6 +10,7 @@ import io.milvus.v2.common.IndexParam;
 import io.milvus.v2.service.collection.request.DropCollectionReq;
 import io.milvus.v2.service.index.request.CreateIndexReq;
 import io.milvus.v2.service.utility.request.CreateAliasReq;
+import io.milvus.v2.service.utility.request.DropAliasReq;
 import io.milvus.v2.service.utility.request.ListAliasesReq;
 import io.milvus.v2.service.utility.response.ListAliasResp;
 import io.milvus.v2.service.vector.request.InsertReq;
@@ -49,6 +50,7 @@ public class CreateAliasTest extends BaseTest {
 
     @AfterClass(alwaysRun = true)
     public void cleanTestData(){
+        milvusClientV2.dropAlias(DropAliasReq.builder().alias(aliasName).build());
         milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(newCollectionName).build());
     }
 

+ 305 - 26
tests/milvustestv2/src/test/java/com/zilliz/milvustestv2/vectorOperation/QueryTest.java

@@ -7,13 +7,26 @@ import com.zilliz.milvustestv2.common.CommonData;
 import com.zilliz.milvustestv2.common.CommonFunction;
 import com.zilliz.milvustestv2.params.FieldParam;
 import com.zilliz.milvustestv2.utils.DataProviderUtils;
+import com.zilliz.milvustestv2.utils.PropertyFilesUtil;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.grpc.GetPersistentSegmentInfoResponse;
+import io.milvus.grpc.GetQuerySegmentInfoResponse;
+import io.milvus.param.ConnectParam;
+import io.milvus.param.R;
+import io.milvus.param.control.GetPersistentSegmentInfoParam;
+import io.milvus.param.control.GetQuerySegmentInfoParam;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
-import io.milvus.v2.service.collection.request.DropCollectionReq;
-import io.milvus.v2.service.collection.request.GetCollectionStatsReq;
-import io.milvus.v2.service.collection.request.LoadCollectionReq;
+import io.milvus.v2.service.collection.request.*;
+import io.milvus.v2.service.collection.response.DescribeCollectionResp;
 import io.milvus.v2.service.collection.response.GetCollectionStatsResp;
+import io.milvus.v2.service.index.request.DescribeIndexReq;
+import io.milvus.v2.service.index.response.DescribeIndexResp;
+import io.milvus.v2.service.utility.request.FlushReq;
+import io.milvus.v2.service.vector.request.GetReq;
 import io.milvus.v2.service.vector.request.InsertReq;
 import io.milvus.v2.service.vector.request.QueryReq;
 import io.milvus.v2.service.vector.response.InsertResp;
@@ -23,10 +36,13 @@ import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
+import org.testng.internal.reflect.MethodMatcherException;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 /**
  * @Author yongpeng.li
@@ -35,28 +51,130 @@ import java.util.List;
 public class QueryTest extends BaseTest {
 
     String nullableDefaultCollectionName;
-    
+    String samplingCollection;
+    String samplingCollectionWithMultiSegment;
+    String samplingCollectionWithNullable;
+    MilvusServiceClient milvusClientV1;
+    long samplingCollectionEntityNum = CommonData.numberEntities * 3;
+    long samplingCollectionWithNullableEntityNum = CommonData.numberEntities * 3;
+    long samplingCollectionWithMultiSegmentEntityNum = CommonData.numberEntities * 20;
+
     @DataProvider(name = "filterAndExcept")
     public Object[][] providerData() {
         return new Object[][]{
+                {CommonData.fieldInt64 + " >= 0 ", CommonData.numberEntities * 3},
                 {CommonData.fieldInt64 + " < 10 ", 10},
-                {CommonData.fieldInt64 + " != 10 ", CommonData.numberEntities*3 - 1},
+                {CommonData.fieldInt64 + " != 10 ", CommonData.numberEntities * 3 - 1},
                 {CommonData.fieldInt64 + " <= 10 ", 11},
                 {"5<" + CommonData.fieldInt64 + " <= 10 ", 5},
-                {CommonData.fieldInt64 + " >= 10 ", CommonData.numberEntities*3 - 10},
-                {CommonData.fieldInt64 + " > 100 ", CommonData.numberEntities*3 - 101},
+                {CommonData.fieldInt64 + " >= 10 ", CommonData.numberEntities * 3 - 10},
+                {CommonData.fieldInt64 + " > 100 ", CommonData.numberEntities * 3 - 101},
                 {CommonData.fieldInt64 + " < 10 and " + CommonData.fieldBool + " == true", 5},
                 {CommonData.fieldInt64 + " in [1,2,3] ", 3},
-                {CommonData.fieldInt64 + " not in [1,2,3] ", CommonData.numberEntities*3 - 3},
+                {CommonData.fieldInt64 + " not in [1,2,3] ", CommonData.numberEntities * 3 - 3},
                 {CommonData.fieldInt64 + " < 10 and " + CommonData.fieldInt32 + " >5 ", 4},
-                {CommonData.fieldVarchar + " > \"0\" ", CommonData.numberEntities*3},
+                {CommonData.fieldVarchar + " > \"0\" ", CommonData.numberEntities * 3},
                 {CommonData.fieldVarchar + " like \"str%\" ", 0},
-                {CommonData.fieldVarchar + " like \"Str%\" ", CommonData.numberEntities*3},
+                {CommonData.fieldVarchar + " like \"Str%\" ", CommonData.numberEntities * 3},
                 {CommonData.fieldVarchar + " like \"Str1\" ", 1},
                 {CommonData.fieldInt8 + " > 129 ", 0},
+                {CommonData.fieldInt32 + " > 0 ", CommonData.numberEntities * 3 - 1},
+                {CommonData.fieldInt32 + " < 10 ", 10},
+                {CommonData.fieldInt32 + " != 10 ", CommonData.numberEntities * 3 - 1},
+                {CommonData.fieldInt32 + " <= 10 ", 11},
+                {"5<" + CommonData.fieldInt32 + " <= 10 ", 5},
+                {CommonData.fieldInt32 + " >= 10 ", CommonData.numberEntities * 3 - 10},
+                {CommonData.fieldInt32 + " > 100 ", CommonData.numberEntities * 3 - 101},
+                {CommonData.fieldInt32 + " < 10 and " + CommonData.fieldBool + " == true", 5},
+                {CommonData.fieldInt32 + " in [1,2,3] ", 3},
+                {CommonData.fieldInt32 + " not in [1,2,3] ", CommonData.numberEntities * 3 - 3},
+                {CommonData.fieldInt32 + " < 10 and " + CommonData.fieldInt32 + " >5 ", 4},
+                {CommonData.fieldFloat + "<= 10", 11},
+                {CommonData.fieldArray + "[0] >= 0", CommonData.numberEntities * 3},
+                {CommonData.fieldArray + "[0] <= 10", 11},
+                {"ARRAY_CONTAINS(" + CommonData.fieldArray + ", 1)", 2},
+                {(CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 10},
+                {CommonData.fieldArray + "[0] < 10 or " + (CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 20},
+                {CommonData.fieldArray + "[0] < 10 ||" + (CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 20},
+                {CommonData.fieldArray + "[0] not in [1,2,3]", CommonData.numberEntities * 3 - 3},
+                {CommonData.fieldArray + "[0]  in [1,2,3]", 3},
+                {CommonData.fieldArray + "[0] != 0", CommonData.numberEntities * 3 - 1},
+                {CommonData.fieldArray + "[1] % 100 == 0 ", (CommonData.numberEntities * 3) / 100},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] < 10", 10},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] % 100 == 0", (CommonData.numberEntities * 3) / 100},
+                {CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 1000 && " + CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] >= 500", 500},
+                {CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 10", 10},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] in [1,2,3]", 3},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] not in [1,2,3]", CommonData.numberEntities * 3 - 3},
+                {"(" + CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 10 || " + CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] >= " + (CommonData.numberEntities * 3 - 10) + ")", 20},
         };
     }
 
+    @DataProvider(name = "filterAndExceptWithNullable")
+    public Object[][] providerDataWithNullable() {
+        return new Object[][]{
+                {CommonData.fieldInt64 + " >= 0 ", CommonData.numberEntities * 3},
+                {CommonData.fieldInt64 + " < 10 ", 10},
+                {CommonData.fieldInt64 + " != 10 ", CommonData.numberEntities * 3 - 1},
+                {CommonData.fieldInt64 + " <= 10 ", 11},
+                {"5<" + CommonData.fieldInt64 + " <= 10 ", 5},
+                {CommonData.fieldInt64 + " >= 10 ", CommonData.numberEntities * 3 - 10},
+                {CommonData.fieldInt64 + " > 100 ", CommonData.numberEntities * 3 - 101},
+                {CommonData.fieldInt64 + " < 10 and " + CommonData.fieldBool + " == true", 2},
+                {CommonData.fieldInt64 + " in [1,2,3] ", 3},
+                {CommonData.fieldInt64 + " not in [1,2,3] ", CommonData.numberEntities * 3 - 3},
+                {CommonData.fieldInt64 + " < 10 and " + CommonData.fieldInt32 + " >5 ", 2},
+                {CommonData.fieldVarchar + " > \"0\" ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldVarchar + " like \"str%\" ", 0},
+                {CommonData.fieldVarchar + " is null ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldVarchar + " is not null ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldVarchar + " like \"Str%\" ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldVarchar + " like \"Str1\" ", 0},
+                {CommonData.fieldInt8 + " > 129 ", 0},
+                {CommonData.fieldInt32 + " > 0 ", (CommonData.numberEntities * 3 / 2) - 1},
+                {CommonData.fieldInt32 + " < 10 ", 5},
+                {CommonData.fieldInt32 + " is null ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldInt32 + " is not null ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldInt32 + " != 10 ", (CommonData.numberEntities * 3 / 2) - 1},
+                {CommonData.fieldInt32 + " <= 10 ", 6},
+                {"5<" + CommonData.fieldInt32 + " <= 10 ", 3},
+                {CommonData.fieldInt32 + " >= 10 ", (CommonData.numberEntities * 3 / 2) - 5},
+                {CommonData.fieldInt32 + " > 100 ", (CommonData.numberEntities * 3 - 101) / 2},
+                {CommonData.fieldInt32 + " < 10 and " + CommonData.fieldBool + " == true", 2},
+                {CommonData.fieldInt32 + " in [1,2,3] ", 1},
+                {CommonData.fieldInt32 + " not in [1,2,3] ", CommonData.numberEntities * 3 / 2 - 1},
+                {CommonData.fieldInt32 + " < 10 and " + CommonData.fieldInt32 + " >5 ", 2},
+                {CommonData.fieldFloat + "<= 10", 6},
+                {CommonData.fieldArray + "[0] >= 0", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldArray + "[0] <= 10", 6},
+                {"ARRAY_CONTAINS(" + CommonData.fieldArray + ", 1)", 1},
+                {CommonData.fieldArray + " is null ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldArray + " is not null ", CommonData.numberEntities * 3 / 2},
+                {(CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 5},
+                {CommonData.fieldArray + "[0] < 10 or " + (CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 10},
+                {CommonData.fieldArray + "[0] < 10 ||" + (CommonData.numberEntities * 3 - 10) + " <= " + CommonData.fieldArray + "[0] ", 10},
+                {CommonData.fieldArray + "[0] not in [1,2,3]", (CommonData.numberEntities * 3 / 2) - 1},
+                {CommonData.fieldArray + "[0]  in [1,2,3]", 1},
+                {CommonData.fieldArray + "[0] != 0", (CommonData.numberEntities * 3 / 2) - 1},
+                {CommonData.fieldArray + "[1] % 100 == 0 ", 0},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] < 10", 5},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] % 100 == 0", CommonData.numberEntities * 3 / 100},
+                {CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 1000 && " + CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] >= 500", 250},
+                {CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 10", 5},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] in [1,2,3]", 1},
+                {CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] not in [1,2,3]", CommonData.numberEntities * 3 - 1},
+                {"(" + CommonData.fieldJson + "['" + CommonData.fieldFloat + "'] < 10 || " + CommonData.fieldJson + "['" + CommonData.fieldInt64 + "'] >= " + (CommonData.numberEntities * 3 - 10) + ")", 10},
+        };
+    }
+
+    @DataProvider(name = "samplingValue")
+    public Object[][] providerSamplingValue() {
+        return new Object[][]{
+                {0.99}, {0.9}, {0.8}, {0.7}, {0.6}, {0.5}, {0.4}, {0.3}, {0.2}, {0.1}, {0.01}, {0.001}, {1}, {0}, {-0.1}, {2.5}, {65536}
+        };
+    }
+
+
     @DataProvider(name = "queryPartition")
     private Object[][] providePartitionQueryParams() {
         return new Object[][]{
@@ -66,21 +184,21 @@ public class QueryTest extends BaseTest {
                 {Lists.newArrayList(CommonData.partitionNameB), CommonData.numberEntities * 2 + " < " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 3, 0},
                 {Lists.newArrayList(CommonData.partitionNameC), CommonData.numberEntities * 2 + " <= " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 3, CommonData.numberEntities},
                 {Lists.newArrayList(CommonData.partitionNameC), CommonData.numberEntities * 3 + " < " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 4, 0},
-                {Lists.newArrayList(CommonData.partitionNameA,CommonData.partitionNameB), "0 <= " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 2, CommonData.numberEntities * 2},
-                {Lists.newArrayList(CommonData.partitionNameA,CommonData.partitionNameB,CommonData.partitionNameC),"0 <= " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 3, CommonData.numberEntities * 3},
+                {Lists.newArrayList(CommonData.partitionNameA, CommonData.partitionNameB), "0 <= " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 2, CommonData.numberEntities * 2},
+                {Lists.newArrayList(CommonData.partitionNameA, CommonData.partitionNameB, CommonData.partitionNameC), "0 <= " + CommonData.fieldInt64 + " <= " + CommonData.numberEntities * 3, CommonData.numberEntities * 3},
         };
     }
 
-    @DataProvider(name="DiffCollectionWithFilter")
-    public Object[][] providerDiffCollectionWithFilter(){
-        Object[][] vectorType=new Object[][]{
+    @DataProvider(name = "DiffCollectionWithFilter")
+    public Object[][] providerDiffCollectionWithFilter() {
+        Object[][] vectorType = new Object[][]{
                 {CommonData.defaultBFloat16VectorCollection},
                 {CommonData.defaultBinaryVectorCollection},
                 {CommonData.defaultFloat16VectorCollection},
                 {CommonData.defaultSparseFloatVectorCollection}
 
         };
-        Object[][] filter=new Object[][]{
+        Object[][] filter = new Object[][]{
                 {CommonData.fieldInt64 + " < 10 ", 10},
                 {CommonData.fieldInt64 + " != 10 ", CommonData.numberEntities - 1},
                 {CommonData.fieldInt64 + " <= 10 ", 11},
@@ -104,11 +222,15 @@ public class QueryTest extends BaseTest {
 
     @BeforeClass(alwaysRun = true)
     public void providerCollection() {
+        milvusClientV1 = new MilvusServiceClient(ConnectParam.newBuilder()
+                .withUri(System.getProperty("uri") == null ? PropertyFilesUtil.getRunValue("uri") : System.getProperty("uri"))
+                .withToken("root:Milvus")
+                .build());
         nullableDefaultCollectionName = CommonFunction.createNewNullableDefaultValueCollection(CommonData.dim, null, DataType.FloatVector);
         // insert data
-        List<JsonObject> jsonObjects = CommonFunction.generateSimpleNullData(0,CommonData.numberEntities, CommonData.dim,DataType.FloatVector);
+        List<JsonObject> jsonObjects = CommonFunction.generateSimpleNullData(0, CommonData.numberEntities, CommonData.dim, DataType.FloatVector);
         InsertResp insert = milvusClientV2.insert(InsertReq.builder().collectionName(nullableDefaultCollectionName).data(jsonObjects).build());
-        CommonFunction.createVectorIndex(nullableDefaultCollectionName,CommonData.fieldFloatVector, IndexParam.IndexType.AUTOINDEX, IndexParam.MetricType.L2);
+        CommonFunction.createVectorIndex(nullableDefaultCollectionName, CommonData.fieldFloatVector, IndexParam.IndexType.AUTOINDEX, IndexParam.MetricType.L2);
         // Build Scalar Index
         List<FieldParam> FieldParamList = new ArrayList<FieldParam>() {{
             add(FieldParam.builder().fieldName(CommonData.fieldVarchar).indextype(IndexParam.IndexType.BITMAP).build());
@@ -122,24 +244,42 @@ public class QueryTest extends BaseTest {
         CommonFunction.createScalarCommonIndex(nullableDefaultCollectionName, FieldParamList);
         milvusClientV2.loadCollection(LoadCollectionReq.builder().collectionName(nullableDefaultCollectionName).build());
         // create partition
-        CommonFunction.createPartition(nullableDefaultCollectionName,CommonData.partitionNameA);
-        List<JsonObject> jsonObjectsA = CommonFunction.generateSimpleNullData(0,CommonData.numberEntities*3, CommonData.dim,DataType.FloatVector);
+        CommonFunction.createPartition(nullableDefaultCollectionName, CommonData.partitionNameA);
+        List<JsonObject> jsonObjectsA = CommonFunction.generateSimpleNullData(0, CommonData.numberEntities * 3, CommonData.dim, DataType.FloatVector);
         milvusClientV2.insert(InsertReq.builder().collectionName(nullableDefaultCollectionName).partitionName(CommonData.partitionNameA).data(jsonObjectsA).build());
+        // create samplingCollection
+        samplingCollection = CommonFunction.createNewCollection(CommonData.dim, samplingCollection, DataType.FloatVector);
+        CommonFunction.createIndexAndInsertAndLoad(samplingCollection, DataType.FloatVector, true, samplingCollectionEntityNum);
+//        milvusClientV2.flush(FlushReq.builder().collectionNames(Lists.newArrayList(samplingCollection)).waitFlushedTimeoutMs(5000L).build());
+        // create samplingCollection
+       /* samplingCollectionWithMultiSegment = CommonFunction.createNewCollection(CommonData.dim, samplingCollectionWithMultiSegment, DataType.FloatVector);
+        CommonFunction.createIndexAndInsertAndLoad(samplingCollection, DataType.FloatVector, true, samplingCollectionWithMultiSegmentEntityNum);
+        milvusClientV2.flush(FlushReq.builder().collectionNames(Lists.newArrayList(samplingCollectionWithMultiSegment)).waitFlushedTimeoutMs(5000L).build());*/
+        // create samplingNullableCollection
+        samplingCollectionWithNullable = CommonFunction.createNewNullableCollection(CommonData.dim, null, DataType.FloatVector);
+        List<JsonObject> generateSimpleNullData = CommonFunction.generateSimpleNullData(0, samplingCollectionWithNullableEntityNum, CommonData.dim, DataType.FloatVector);
+        milvusClientV2.insert(InsertReq.builder().collectionName(samplingCollectionWithNullable).data(generateSimpleNullData).build());
+        CommonFunction.createVectorIndex(samplingCollectionWithNullable, CommonData.fieldFloatVector, IndexParam.IndexType.AUTOINDEX, IndexParam.MetricType.L2);
+        milvusClientV2.loadCollection(LoadCollectionReq.builder().collectionName(samplingCollectionWithNullable).build());
+
     }
 
     @AfterClass(alwaysRun = true)
     public void cleanTestData() {
         milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(nullableDefaultCollectionName).build());
+        milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(samplingCollection).build());
+        milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(samplingCollectionWithNullable).build());
+//        milvusClientV2.dropCollection(DropCollectionReq.builder().collectionName(samplingCollectionWithMultiSegment).build());
     }
 
     @DataProvider(name = "queryNullableField")
     private Object[][] provideNullableFieldQueryParams() {
         return new Object[][]{
-                {CommonData.fieldInt32 + " == 1 ", CommonData.numberEntities*3},
-                {CommonData.fieldDouble + " > 1 ", CommonData.numberEntities*3/2 - 1},
-                {CommonData.fieldVarchar + " == \"1.0\" ", CommonData.numberEntities*3/2},
-                {CommonData.fieldFloat + " == 1.0 ", CommonData.numberEntities*3/2 + 1},
-                {"fieldJson[\"" + CommonData.fieldVarchar + "\"] in [\"Str1\", \"Str3\"]", 2},
+                {CommonData.fieldInt32 + " == 1 ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldDouble + " > 1 ", CommonData.numberEntities * 3 / 2 - 1},
+                {CommonData.fieldVarchar + " == \"1.0\" ", CommonData.numberEntities * 3 / 2},
+                {CommonData.fieldFloat + " == 1.0 ", CommonData.numberEntities * 3 / 2},
+                {"fieldJson[\"" + CommonData.fieldVarchar + "\"] in [\"Str1\", \"Str3\"]", 0},
                 {"ARRAY_CONTAINS(" + CommonData.fieldArray + ", 1)", 1},
         };
     }
@@ -209,7 +349,7 @@ public class QueryTest extends BaseTest {
     }
 
     @Test(description = "query with different collection", groups = {"Smoke"}, dataProvider = "DiffCollectionWithFilter")
-    public void queryDiffCollection(String collectionName,String filter, long expect) {
+    public void queryDiffCollection(String collectionName, String filter, long expect) {
         QueryResp query = milvusClientV2.query(QueryReq.builder()
                 .collectionName(collectionName)
                 .filter(filter)
@@ -229,4 +369,143 @@ public class QueryTest extends BaseTest {
                 .build());
         Assert.assertEquals(query.getQueryResults().size(), expect);
     }
+
+    @Test(description = "sampling test", groups = {"Smoke"}, dataProvider = "filterAndExcept")
+    public void samplingTest(String filter, long expect) {
+//        // 查询collection的segment数量
+//        int segmentSize = 0;
+//        do {
+//            R<GetQuerySegmentInfoResponse> querySegmentInfo =
+//                    milvusClientV1.getQuerySegmentInfo(GetQuerySegmentInfoParam.newBuilder().withCollectionName(samplingCollection).build());
+//            segmentSize = querySegmentInfo.getData().getInfosList().size();
+//        } while (segmentSize == 0);
+        double samplingRate = 0.1;
+        String samplingFilter = "(" + filter + " )&& random_sample(" + samplingRate + ")";
+        long samplingExpect;
+        System.out.println("expect*samplingRate:" + expect * samplingRate);
+        if (expect * samplingRate <= 1) {
+            samplingExpect = 1;
+        } else {
+            samplingExpect = (long) (expect * samplingRate);
+        }
+        QueryResp query = milvusClientV2.query(QueryReq.builder()
+                .collectionName(samplingCollection)
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .outputFields(Lists.newArrayList("*"))
+                .filter(samplingFilter)
+                .build());
+        System.out.println("expect:" + samplingExpect);
+        System.out.println("actual:" + query.getQueryResults().size());
+        Assert.assertTrue(query.getQueryResults().size() <= samplingExpect);
+        Assert.assertTrue(query.getQueryResults().size() >= (samplingExpect - 1));
+    }
+
+    @Test(description = "sampling test with nullable value", groups = {"Smoke"}, dataProvider = "filterAndExceptWithNullable")
+    public void samplingTestWithNullable(String filter, long expect) {
+        double samplingRate = 0.1;
+        String samplingFilter = "(" + filter + " )&& random_sample(" + samplingRate + ")";
+        long samplingExpect;
+        System.out.println("expect*samplingRate:" + expect * samplingRate);
+        if (expect * samplingRate <= 1) {
+            samplingExpect = 1;
+        } else {
+            samplingExpect = (long) (expect * samplingRate);
+        }
+        QueryResp query = milvusClientV2.query(QueryReq.builder()
+                .collectionName(samplingCollectionWithNullable)
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .outputFields(Lists.newArrayList("*"))
+                .filter(samplingFilter)
+                .build());
+        System.out.println("expect:" + samplingExpect);
+        System.out.println("actual:" + query.getQueryResults().size());
+//        Assert.assertEquals(expect, query.getQueryResults().size());
+        Assert.assertTrue(query.getQueryResults().size() <= samplingExpect);
+        Assert.assertTrue(query.getQueryResults().size() >= (samplingExpect - 1));
+    }
+
+    @Test(description = "sampling test with limit", groups = {"Smoke"}, dataProvider = "filterAndExcept")
+    public void samplingWithLimitTest(String filter, long expect) {
+//        // 查询collection的segment数量
+//        int segmentSize = 0;
+//        do {
+//            R<GetQuerySegmentInfoResponse> querySegmentInfo =
+//                    milvusClientV1.getQuerySegmentInfo(GetQuerySegmentInfoParam.newBuilder().withCollectionName(samplingCollection).build());
+//            segmentSize = querySegmentInfo.getData().getInfosList().size();
+//        } while (segmentSize == 0);
+
+        QueryResp query = milvusClientV2.query(QueryReq.builder()
+                .collectionName(samplingCollection)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .outputFields(Lists.newArrayList("*"))
+                .filter(filter)
+                .limit(5)
+                .build());
+        System.out.println("expect:" + expect);
+        System.out.println("actual:" + query.getQueryResults().size());
+        expect = expect > 5 ? 5 : expect;
+        Assert.assertTrue(query.getQueryResults().size() <= expect);
+    }
+
+    @Test(description = "query filter with different sampling value", groups = {"L1"}, dataProvider = "samplingValue")
+    public void queryByFilterWithDifferentValue(double samplingValue) {
+        String filter = CommonData.fieldInt64 + " >= 0 && random_sample(" + samplingValue + ") ";
+        try {
+            QueryResp query = milvusClientV2.query(QueryReq.builder()
+                    .collectionName(samplingCollection)
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .outputFields(Lists.newArrayList("*"))
+                    .filter(filter)
+                    .build());
+            Assert.assertEquals(query.getQueryResults().size(), samplingCollectionEntityNum * samplingValue);
+        } catch (MethodMatcherException e) {
+            System.out.println("MethodMatcherException:" + e.getMessage());
+        } catch (Exception e) {
+            Assert.assertTrue(samplingValue >= 1 || samplingValue <= 0);
+            Assert.assertTrue(e.getMessage().contains("should be between 0 and 1 and not too close to 0 or 1"));
+        }
+    }
+
+    @Test(description = "query ids with different sampling value", groups = {"L1"})
+    public void queryInIdsWithDifferentValue() {
+        String filter = "random_sample(0.1) ";
+        List<Object> ids = IntStream.range(0, 100)
+                .boxed()
+                .collect(Collectors.toList());
+        try {
+            QueryResp query = milvusClientV2.query(QueryReq.builder()
+                    .collectionName(samplingCollection)
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .outputFields(Lists.newArrayList("*"))
+                    .ids(ids)
+                    .filter(filter)
+                    .build());
+        } catch (Exception e) {
+            Assert.assertTrue(e.getMessage().contains("can't be set at the same time"));
+        }
+    }
+
+
+    @Test(description = "sampling test with multi segment", groups = {"L2"}, dataProvider = "samplingValue")
+    public void samplingWithMultiSegment(double samplingValue) {
+
+        // 查询collection的segment数量
+        int segmentSize = 0;
+        do {
+            R<GetQuerySegmentInfoResponse> querySegmentInfo =
+                    milvusClientV1.getQuerySegmentInfo(GetQuerySegmentInfoParam.newBuilder().withCollectionName(samplingCollectionWithMultiSegment).build());
+            segmentSize = querySegmentInfo.getData().getInfosList().size();
+        } while (segmentSize == 0);
+        System.out.println("segmentSize: " + segmentSize);
+        String filter = CommonData.fieldInt64 + " >= 0 && random_sample(" + samplingCollectionWithMultiSegmentEntityNum + ") ";
+        QueryResp query = milvusClientV2.query(QueryReq.builder()
+                .collectionName(samplingCollectionWithMultiSegment)
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .outputFields(Lists.newArrayList("*"))
+                .filter(filter)
+                .build());
+        Assert.assertEquals(query.getQueryResults().size(), samplingCollectionWithMultiSegmentEntityNum * samplingValue);
+    }
+
+
 }

+ 1 - 1
tests/milvustestv2/src/test/java/com/zilliz/milvustestv2/vectorOperation/SearchTest.java

@@ -126,7 +126,7 @@ public class SearchTest extends BaseTest {
                 {CommonData.fieldDouble + " > 1 ", topK},
                 {CommonData.fieldVarchar + " == \"1.0\" ", topK},
                 {CommonData.fieldFloat + " == 1.0 ", topK},
-                {"fieldJson[\"" + CommonData.fieldVarchar + "\"] in [\"Str1\", \"Str3\"]", 2},
+                {"fieldJson[\"" + CommonData.fieldVarchar + "\"] in [\"Str1\", \"Str3\"]", 0},
                 {"ARRAY_CONTAINS(" + CommonData.fieldArray + ", 1)", 1},
         };
     }

+ 1 - 1
tests/milvustestv2/src/test/java/com/zilliz/milvustestv2/vectorOperation/UpsertTest.java

@@ -139,7 +139,7 @@ public class UpsertTest extends BaseTest {
         // query
         QueryResp query = milvusClientV2.query(QueryReq.builder()
                 .collectionName(collectionName)
-                .filter(CommonData.fieldInt32 + " == 1")
+                .filter(CommonData.fieldInt32 + " == 0")
                 .partitionNames(Lists.newArrayList(CommonData.defaultPartitionName))
                 .outputFields(Lists.newArrayList(CommonData.fieldInt64, CommonData.fieldInt32))
                 .consistencyLevel(ConsistencyLevel.STRONG).build());

+ 1 - 1
tests/milvustestv2/src/test/resources/run.properties

@@ -1,3 +1,3 @@
 uri=http://127.0.0.1:19530
-
+minio=http://127.0.0.1:9000