Browse Source

Add test case for all supported indexed (#342)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
a505b3e1b7
1 changed files with 134 additions and 4 deletions
  1. 134 4
      src/test/java/io/milvus/client/MilvusClientDockerTest.java

+ 134 - 4
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -44,10 +44,7 @@ import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 
 
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 
 
@@ -1031,4 +1028,137 @@ class MilvusClientDockerTest {
             System.out.println(scores);
             System.out.println(scores);
         }
         }
     }
     }
+
+    private static void testIndex(String collectionName, String fieldName,
+                                  IndexType type, MetricType metric,
+                                  String params, Boolean syncMode) {
+        // create index
+        CreateIndexParam indexParam = CreateIndexParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withFieldName(fieldName)
+                .withIndexName("index")
+                .withIndexType(type)
+                .withMetricType(metric)
+                .withExtraParam(params)
+                .withSyncMode(syncMode)
+                .build();
+
+        R<RpcStatus> createIndexR = client.createIndex(indexParam);
+        assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        // drop index
+        DropIndexParam dropIndexParam = DropIndexParam.newBuilder()
+                .withCollectionName(collectionName)
+                .withIndexName(indexParam.getIndexName())
+                .build();
+        R<RpcStatus> dropIndexR = client.dropIndex(dropIndexParam);
+        assertEquals(R.Status.Success.getCode(), dropIndexR.getStatus().intValue());
+    }
+
+    @Test
+    void testFloatVectorIndex() {
+        String randomCollectionName = generator.generate(10);
+
+        // collection schema
+        String field1Name = "idg_field";
+        String field2Name = "vec_field";
+        List<FieldType> fieldsSchema = new ArrayList<>();
+        fieldsSchema.add(FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build());
+
+        fieldsSchema.add(FieldType.newBuilder()
+                .withDataType(DataType.FloatVector)
+                .withName(field2Name)
+                .withDimension(dimension)
+                .build());
+
+        // create collection
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withDescription("test")
+                .withFieldTypes(fieldsSchema)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        // test all supported indexes
+        Map<IndexType, String> indexTypes = new HashMap<>();
+        indexTypes.put(IndexType.FLAT, "{}");
+        indexTypes.put(IndexType.IVF_FLAT, "{\"nlist\":128}");
+        indexTypes.put(IndexType.IVF_SQ8, "{\"nlist\":128}");
+        indexTypes.put(IndexType.IVF_PQ, "{\"nlist\":128, \"m\":16, \"nbits\":8}");
+        indexTypes.put(IndexType.ANNOY, "{\"n_trees\":16}");
+        indexTypes.put(IndexType.HNSW, "{\"M\":16,\"efConstruction\":64}");
+        indexTypes.put(IndexType.RHNSW_FLAT, "{\"M\":16,\"efConstruction\":64}");
+        indexTypes.put(IndexType.RHNSW_PQ, "{\"M\":16,\"efConstruction\":64, \"PQM\":16}");
+        indexTypes.put(IndexType.RHNSW_SQ, "{\"M\":16,\"efConstruction\":64}");
+
+        List<MetricType> metricTypes = new ArrayList<>();
+        metricTypes.add(MetricType.L2);
+        metricTypes.add(MetricType.IP);
+
+        for (IndexType type : indexTypes.keySet()) {
+            for (MetricType metric : metricTypes) {
+                testIndex(randomCollectionName, field2Name, type, metric, indexTypes.get(type), Boolean.TRUE);
+                testIndex(randomCollectionName, field2Name, type, metric, indexTypes.get(type), Boolean.FALSE);
+            }
+        }
+    }
+
+    @Test
+    void testBinaryVectorIndex() {
+        String randomCollectionName = generator.generate(10);
+
+        // collection schema
+        String field1Name = "id_field";
+        String field2Name = "vector_field";
+        FieldType field1 = FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(true)
+                .withDataType(DataType.Int64)
+                .withName(field1Name)
+                .build();
+
+        FieldType field2 = FieldType.newBuilder()
+                .withDataType(DataType.BinaryVector)
+                .withName(field2Name)
+                .withDimension(dimension)
+                .build();
+
+        // create collection
+        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withDescription("test")
+                .addFieldType(field1)
+                .addFieldType(field2)
+                .build();
+
+        R<RpcStatus> createR = client.createCollection(createParam);
+        assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        // test all supported indexes
+        List<MetricType> flatMetricTypes = new ArrayList<>();
+        flatMetricTypes.add(MetricType.SUBSTRUCTURE);
+        flatMetricTypes.add(MetricType.SUPERSTRUCTURE);
+
+        for (MetricType metric : flatMetricTypes) {
+            testIndex(randomCollectionName, field2Name, IndexType.BIN_FLAT, metric, "{}", Boolean.TRUE);
+            testIndex(randomCollectionName, field2Name, IndexType.BIN_FLAT, metric, "{}", Boolean.FALSE);
+        }
+
+        List<MetricType> ivfMetricTypes = new ArrayList<>();
+        ivfMetricTypes.add(MetricType.HAMMING);
+        ivfMetricTypes.add(MetricType.JACCARD);
+        ivfMetricTypes.add(MetricType.TANIMOTO);
+
+        for (MetricType metric : ivfMetricTypes) {
+            testIndex(randomCollectionName, field2Name, IndexType.BIN_IVF_FLAT, metric, "{\"nlist\":128}", Boolean.TRUE);
+            testIndex(randomCollectionName, field2Name, IndexType.BIN_IVF_FLAT, metric, "{\"nlist\":128}", Boolean.FALSE);
+        }
+    }
 }
 }