Browse Source

Add unittest cases (#234)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 years ago
parent
commit
383792cc5b

+ 2 - 1
src/main/java/io/milvus/Response/GetCollStatResponseWrapper.java

@@ -2,6 +2,7 @@ package io.milvus.Response;
 
 
 import io.milvus.grpc.GetCollectionStatisticsResponse;
 import io.milvus.grpc.GetCollectionStatisticsResponse;
 import io.milvus.grpc.KeyValuePair;
 import io.milvus.grpc.KeyValuePair;
+import io.milvus.param.Constant;
 import lombok.NonNull;
 import lombok.NonNull;
 
 
 import java.util.List;
 import java.util.List;
@@ -25,7 +26,7 @@ public class GetCollStatResponseWrapper {
     public long getRowCount() throws NumberFormatException {
     public long getRowCount() throws NumberFormatException {
         List<KeyValuePair> stats = stat.getStatsList();
         List<KeyValuePair> stats = stat.getStatsList();
         for (KeyValuePair kv : stats) {
         for (KeyValuePair kv : stats) {
-            if (kv.getKey().compareTo("row_count") == 0) {
+            if (kv.getKey().compareTo(Constant.ROW_COUNT) == 0) {
                 return Long.parseLong(kv.getValue());
                 return Long.parseLong(kv.getValue());
             }
             }
         }
         }

+ 2 - 1
src/main/java/io/milvus/Response/GetPartStatResponseWrapper.java

@@ -2,6 +2,7 @@ package io.milvus.Response;
 
 
 import io.milvus.grpc.GetPartitionStatisticsResponse;
 import io.milvus.grpc.GetPartitionStatisticsResponse;
 import io.milvus.grpc.KeyValuePair;
 import io.milvus.grpc.KeyValuePair;
+import io.milvus.param.Constant;
 import lombok.NonNull;
 import lombok.NonNull;
 
 
 import java.util.List;
 import java.util.List;
@@ -25,7 +26,7 @@ public class GetPartStatResponseWrapper {
     public long getRowCount() throws NumberFormatException {
     public long getRowCount() throws NumberFormatException {
         List<KeyValuePair> stats = stat.getStatsList();
         List<KeyValuePair> stats = stat.getStatsList();
         for (KeyValuePair kv : stats) {
         for (KeyValuePair kv : stats) {
-            if (kv.getKey().compareTo("row_count") == 0) {
+            if (kv.getKey().compareTo(Constant.ROW_COUNT) == 0) {
                 return Long.parseLong(kv.getValue());
                 return Long.parseLong(kv.getValue());
             }
             }
         }
         }

+ 1 - 1
src/main/java/io/milvus/Response/ShowCollResponseWrapper.java

@@ -48,7 +48,7 @@ public class ShowCollResponseWrapper {
      *
      *
      * @return <code>CollectionInfo</code> information of the collection
      * @return <code>CollectionInfo</code> information of the collection
      */
      */
-    public CollectionInfo getCollectionInfo(@NonNull String name) {
+    public CollectionInfo getCollectionInfoByName(@NonNull String name) {
         for (int i = 0; i < response.getCollectionNamesCount(); ++i) {
         for (int i = 0; i < response.getCollectionNamesCount(); ++i) {
             if ( name.compareTo(response.getCollectionNames(i)) == 0) {
             if ( name.compareTo(response.getCollectionNames(i)) == 0) {
                 CollectionInfo info = new CollectionInfo(response.getCollectionNames(i), response.getCollectionIds(i),
                 CollectionInfo info = new CollectionInfo(response.getCollectionNames(i), response.getCollectionIds(i),

+ 1 - 0
src/main/java/io/milvus/param/Constant.java

@@ -32,6 +32,7 @@ public class Constant {
     public static final String METRIC_TYPE = "metric_type";
     public static final String METRIC_TYPE = "metric_type";
     public static final String ROUND_DECIMAL = "round_decimal";
     public static final String ROUND_DECIMAL = "round_decimal";
     public static final String PARAMS = "params";
     public static final String PARAMS = "params";
+    public static final String ROW_COUNT = "row_count";
 
 
     // max value for waiting loading collection/partition interval, unit: millisecond
     // max value for waiting loading collection/partition interval, unit: millisecond
     public static final Long MAX_WAITING_LOADING_INTERVAL = 2000L;
     public static final Long MAX_WAITING_LOADING_INTERVAL = 2000L;

+ 8 - 0
src/main/java/io/milvus/param/collection/FieldType.java

@@ -51,6 +51,14 @@ public class FieldType {
         this.autoID = builder.autoID;
         this.autoID = builder.autoID;
     }
     }
 
 
+    public int getDimension() {
+        if (typeParams.containsKey(Constant.VECTOR_DIM)) {
+            return Integer.valueOf(typeParams.get(Constant.VECTOR_DIM));
+        }
+
+        return 0;
+    }
+
     public static Builder newBuilder() {
     public static Builder newBuilder() {
         return new Builder();
         return new Builder();
     }
     }

+ 186 - 7
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -44,6 +44,7 @@ import java.util.*;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
 
 
 import static org.junit.jupiter.api.Assertions.*;
 import static org.junit.jupiter.api.Assertions.*;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 
 
 class MilvusServiceClientTest {
 class MilvusServiceClientTest {
     private final int testPort = 53019;
     private final int testPort = 53019;
@@ -1861,9 +1862,112 @@ class MilvusServiceClientTest {
         assertThrows(IllegalResponseException.class, wrapper::getDim);
         assertThrows(IllegalResponseException.class, wrapper::getDim);
     }
     }
 
 
+    @Test
+    void testDescCollResponseWrapper() {
+        String collName = "test";
+        String collDesc = "test col";
+        long collId = 100;
+        int shardNum = 10;
+        long utcTs = 9999;
+        List<String> aliases = Collections.singletonList("a1");
+
+        String fieldName = "f1";
+        String fieldDesc = "f1 field";
+        final boolean autoId = false;
+        final boolean primaryKey = true;
+        DataType dt = DataType.Double;
+        int dim = 256;
+        KeyValuePair kv = KeyValuePair.newBuilder()
+                .setKey(Constant.VECTOR_DIM).setValue(String.valueOf(dim)).build();
+        FieldSchema field = FieldSchema.newBuilder()
+                .setName(fieldName)
+                .setDescription(fieldDesc)
+                .setAutoID(autoId)
+                .setIsPrimaryKey(primaryKey)
+                .setDataType(dt)
+                .addTypeParams(kv)
+                .build();
+
+        CollectionSchema schema = CollectionSchema.newBuilder()
+                .setName(collName)
+                .setDescription(collDesc)
+                .addFields(field)
+                .build();
+
+        DescribeCollectionResponse response = DescribeCollectionResponse.newBuilder()
+                .setCollectionID(collId)
+                .addAllAliases(aliases)
+                .setShardsNum(shardNum)
+                .setCreatedUtcTimestamp(utcTs)
+                .setSchema(schema)
+                .build();
+
+        DescCollResponseWrapper wrapper = new DescCollResponseWrapper(response);
+        assertEquals(collName, wrapper.getCollectionName());
+        assertEquals(collDesc, wrapper.getCollectionDescription());
+        assertEquals(collId, wrapper.getCollectionID());
+        assertEquals(shardNum, wrapper.getShardNumber());
+        assertEquals(aliases.size(), wrapper.getAliases().size());
+        assertEquals(utcTs, wrapper.getCreatedUtcTimestamp());
+        assertEquals(1, wrapper.getFields().size());
+
+        assertNull(wrapper.getFieldByName(""));
+
+        FieldType ft = wrapper.getFieldByName(fieldName);
+        assertEquals(fieldName, ft.getName());
+        assertEquals(fieldDesc, ft.getDescription());
+        assertEquals(dt, ft.getDataType());
+        assertEquals(autoId, ft.isAutoID());
+        assertEquals(primaryKey, ft.isPrimaryKey());
+        assertEquals(dim, ft.getDimension());
+
+        assertFalse(wrapper.toString().isEmpty());
+    }
+
+    @Test
+    void testDescIndexResponseWrapper() {
+        final long indexId = 888;
+        String indexName = "idx";
+        String fieldName = "f1";
+        IndexType indexType = IndexType.IVF_FLAT;
+        MetricType metricType = MetricType.IP;
+        String extraParam = "{nlist:10}";
+        KeyValuePair kvIndexType = KeyValuePair.newBuilder()
+                .setKey(Constant.INDEX_TYPE).setValue(indexType.name()).build();
+        KeyValuePair kvMetricType = KeyValuePair.newBuilder()
+                .setKey(Constant.METRIC_TYPE).setValue(metricType.name()).build();
+        KeyValuePair kvExtraParam = KeyValuePair.newBuilder()
+                .setKey(Constant.PARAMS).setValue(extraParam).build();
+        IndexDescription desc = IndexDescription.newBuilder()
+                .setIndexID(indexId)
+                .setIndexName(indexName)
+                .setFieldName(fieldName)
+                .addParams(kvIndexType)
+                .addParams(kvMetricType)
+                .addParams(kvExtraParam)
+                .build();
+        DescribeIndexResponse response = DescribeIndexResponse.newBuilder()
+                .addIndexDescriptions(desc)
+                .build();
+
+        DescIndexResponseWrapper wrapper = new DescIndexResponseWrapper(response);
+        assertEquals(1, wrapper.getIndexDescriptions().size());
+        assertNull(wrapper.getIndexDescByFieldName(""));
+
+        DescIndexResponseWrapper.IndexDesc indexDesc = wrapper.getIndexDescByFieldName(fieldName);
+        assertEquals(indexId, indexDesc.getId());
+        assertEquals(indexName, indexDesc.getIndexName());
+        assertEquals(fieldName, indexDesc.getFieldName());
+        assertEquals(indexType, indexDesc.getIndexType());
+        assertEquals(metricType, indexDesc.getMetricType());
+        assertEquals(0, extraParam.compareTo(indexDesc.getExtraParam()));
+
+        assertFalse(wrapper.toString().isEmpty());
+    }
+
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     @Test
     @Test
-    void fieldDataWrapper() {
+    void testFieldDataWrapper() {
         // for float vector
         // for float vector
         long dim = 3;
         long dim = 3;
         List<Float> floatVectors = Arrays.asList(1F, 2F, 3F, 4F, 5F, 6F);
         List<Float> floatVectors = Arrays.asList(1F, 2F, 3F, 4F, 5F, 6F);
@@ -1966,15 +2070,15 @@ class MilvusServiceClientTest {
     }
     }
 
 
     @Test
     @Test
-    void getCollStatResponseWrapper() {
+    void testGetCollStatResponseWrapper() {
         GetCollectionStatisticsResponse response = GetCollectionStatisticsResponse.newBuilder()
         GetCollectionStatisticsResponse response = GetCollectionStatisticsResponse.newBuilder()
-                .addStats(KeyValuePair.newBuilder().setKey("row_count").setValue("invalid").build())
+                .addStats(KeyValuePair.newBuilder().setKey(Constant.ROW_COUNT).setValue("invalid").build())
                 .build();
                 .build();
         GetCollStatResponseWrapper invalidWrapper = new GetCollStatResponseWrapper(response);
         GetCollStatResponseWrapper invalidWrapper = new GetCollStatResponseWrapper(response);
         assertThrows(NumberFormatException.class, invalidWrapper::getRowCount);
         assertThrows(NumberFormatException.class, invalidWrapper::getRowCount);
 
 
         response = GetCollectionStatisticsResponse.newBuilder()
         response = GetCollectionStatisticsResponse.newBuilder()
-                .addStats(KeyValuePair.newBuilder().setKey("row_count").setValue("10").build())
+                .addStats(KeyValuePair.newBuilder().setKey(Constant.ROW_COUNT).setValue("10").build())
                 .build();
                 .build();
         GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper(response);
         GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper(response);
         assertEquals(10, wrapper.getRowCount());
         assertEquals(10, wrapper.getRowCount());
@@ -1985,7 +2089,24 @@ class MilvusServiceClientTest {
     }
     }
 
 
     @Test
     @Test
-    void MutationResultWrapper() {
+    void testGetPartStatResponseWrapper() {
+        final long rowCount = 500;
+        KeyValuePair kvStat = KeyValuePair.newBuilder()
+                .setKey(Constant.ROW_COUNT).setValue(String.valueOf(rowCount)).build();
+        GetPartitionStatisticsResponse response = GetPartitionStatisticsResponse.newBuilder()
+                .addStats(kvStat).build();
+
+        GetPartStatResponseWrapper wrapper = new GetPartStatResponseWrapper(response);
+        assertEquals(rowCount, wrapper.getRowCount());
+
+        response = GetPartitionStatisticsResponse.newBuilder().build();
+
+        wrapper = new GetPartStatResponseWrapper(response);
+        assertEquals(0, wrapper.getRowCount());
+    }
+
+    @Test
+    void testMutationResultWrapper() {
         List<Long> nID = Arrays.asList(1L, 2L, 3L);
         List<Long> nID = Arrays.asList(1L, 2L, 3L);
         MutationResult results = MutationResult.newBuilder()
         MutationResult results = MutationResult.newBuilder()
                 .setInsertCnt(nID.size())
                 .setInsertCnt(nID.size())
@@ -2026,7 +2147,7 @@ class MilvusServiceClientTest {
     }
     }
 
 
     @Test
     @Test
-    void QueryResultsWrapper() {
+    void testQueryResultsWrapper() {
         String fieldName = "test";
         String fieldName = "test";
         QueryResults results = QueryResults.newBuilder()
         QueryResults results = QueryResults.newBuilder()
                 .addFieldsData(FieldData.newBuilder()
                 .addFieldsData(FieldData.newBuilder()
@@ -2040,7 +2161,7 @@ class MilvusServiceClientTest {
     }
     }
 
 
     @Test
     @Test
-    void SearchResultsWrapper() {
+    void testSearchResultsWrapper() {
         long topK = 5;
         long topK = 5;
         long numQueries = 2;
         long numQueries = 2;
         List<Long> longIDs = new ArrayList<>();
         List<Long> longIDs = new ArrayList<>();
@@ -2092,5 +2213,63 @@ class MilvusServiceClientTest {
         SearchResultsWrapper strWrapper = new SearchResultsWrapper(results);
         SearchResultsWrapper strWrapper = new SearchResultsWrapper(results);
         idScores = strWrapper.getIDScore(0);
         idScores = strWrapper.getIDScore(0);
         assertEquals(idScores.size(), topK);
         assertEquals(idScores.size(), topK);
+
+        idScores.forEach((score)->assertFalse(score.toString().isEmpty()));
+    }
+
+    @Test
+    void testShowCollResponseWrapper() {
+        List<String> names = Arrays.asList("coll_1", "coll_2");
+        List<Long> ids = Arrays.asList(1L, 2L);
+        List<Long> ts = Arrays.asList(888L, 999L);
+        List<Long> inMemory = Arrays.asList(100L, 50L);
+        ShowCollectionsResponse response = ShowCollectionsResponse.newBuilder()
+                .addAllCollectionNames(names)
+                .addAllCollectionIds(ids)
+                .addAllCreatedUtcTimestamps(ts)
+                .addAllInMemoryPercentages(inMemory)
+                .build();
+
+        ShowCollResponseWrapper wrapper = new ShowCollResponseWrapper(response);
+        assertEquals(names.size(), wrapper.getCollectionsInfo().size());
+        assertFalse(wrapper.toString().isEmpty());
+
+        for (int i = 0; i < 2; ++i) {
+            ShowCollResponseWrapper.CollectionInfo info = wrapper.getCollectionInfoByName(names.get(i));
+            assertEquals(names.get(i).compareTo(info.getName()), 0);
+            assertEquals(ids.get(i), info.getId());
+            assertEquals(ts.get(i), info.getUtcTimestamp());
+            assertEquals(inMemory.get(i), info.getInMemoryPercentage());
+
+            assertFalse(info.toString().isEmpty());
+        }
+    }
+
+    @Test
+    void testShowPartResponseWrapper() {
+        List<String> names = Arrays.asList("part_1", "part_2");
+        List<Long> ids = Arrays.asList(1L, 2L);
+        List<Long> ts = Arrays.asList(888L, 999L);
+        List<Long> inMemory = Arrays.asList(100L, 50L);
+        ShowPartitionsResponse response = ShowPartitionsResponse.newBuilder()
+                .addAllPartitionNames(names)
+                .addAllPartitionIDs(ids)
+                .addAllCreatedUtcTimestamps(ts)
+                .addAllInMemoryPercentages(inMemory)
+                .build();
+
+        ShowPartResponseWrapper wrapper = new ShowPartResponseWrapper(response);
+        assertEquals(names.size(), wrapper.getPartitionsInfo().size());
+        assertFalse(wrapper.toString().isEmpty());
+
+        for (int i = 0; i < 2; ++i) {
+            ShowPartResponseWrapper.PartitionInfo info = wrapper.getPartitionInfoByName(names.get(i));
+            assertEquals(names.get(i).compareTo(info.getName()), 0);
+            assertEquals(ids.get(i), info.getId());
+            assertEquals(ts.get(i), info.getUtcTimestamp());
+            assertEquals(inMemory.get(i), info.getInMemoryPercentage());
+
+            assertFalse(info.toString().isEmpty());
+        }
     }
     }
 }
 }