Prechádzať zdrojové kódy

Fix a bug of nullable Array type (#1367)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 týždňov pred
rodič
commit
6863541086

+ 2 - 2
docker-compose.yml

@@ -32,7 +32,7 @@ services:
 
   standalone:
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:v2.5.4
+    image: milvusdb/milvus:v2.5.8
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-test-slave-standalone
-    image: milvusdb/milvus:v2.5.7
+    image: milvusdb/milvus:v2.5.8
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcdslave:2379

+ 12 - 0
examples/src/main/java/io/milvus/v1/NullAndDefaultExample.java

@@ -36,6 +36,7 @@ public class NullAndDefaultExample {
                 .addOutField("nullable_test")
                 .addOutField("default_test")
                 .addOutField("nullable_default")
+                .addOutField("nullable_array")
                 .build());
         QueryResultsWrapper queryWrapper = new QueryResultsWrapper(queryRet.getData());
         System.out.println("\nQuery with expression: " + expr);
@@ -83,6 +84,14 @@ public class NullAndDefaultExample {
                         .withMaxLength(64)
                         .withDefaultValue("I am default value")
                         .withNullable(true)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName("nullable_array")
+                        .withDataType(DataType.Array)
+                        .withElementType(DataType.VarChar)
+                        .withMaxCapacity(10)
+                        .withMaxLength(100)
+                        .withNullable(true)
                         .build()
         );
 
@@ -130,6 +139,9 @@ public class NullAndDefaultExample {
                 row.addProperty("nullable_test", i);
             } else {
                 row.add("nullable_test", JsonNull.INSTANCE);
+
+                List<String> arr = Arrays.asList("A", "B", "C");
+                row.add("nullable_array", gson.toJsonTree(arr));
             }
 
             // some values are default value

+ 12 - 1
examples/src/main/java/io/milvus/v2/NullAndDefaultExample.java

@@ -28,7 +28,7 @@ public class NullAndDefaultExample {
         QueryResp queryRet = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .filter(expr)
-                .outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default"))
+                .outputFields(Arrays.asList("nullable_test", "default_test", "nullable_default", "nullable_array"))
                 .build());
         System.out.println("\nQuery with expression: " + expr);
         List<QueryResp.QueryResult> records = queryRet.getQueryResults();
@@ -81,6 +81,14 @@ public class NullAndDefaultExample {
                 .isNullable(true)
                 .defaultValue("I am default value")
                 .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("nullable_array")
+                .dataType(DataType.Array)
+                .elementType(DataType.VarChar)
+                .maxCapacity(10)
+                .maxLength(100)
+                .isNullable(true)
+                .build());
 
         List<IndexParam> indexes = new ArrayList<>();
         indexes.add(IndexParam.builder()
@@ -111,6 +119,9 @@ public class NullAndDefaultExample {
                 row.addProperty("nullable_test", i);
             } else {
                 row.add("nullable_test", JsonNull.INSTANCE);
+
+                List<String> arr = Arrays.asList("A", "B", "C");
+                row.add("nullable_array", gson.toJsonTree(arr));
             }
 
             // some values are default value

+ 7 - 7
sdk-core/src/main/java/io/milvus/param/ParamUtils.java

@@ -1329,40 +1329,40 @@ public class ParamUtils {
             case UNRECOGNIZED:
                 throw new ParamException("Cannot support this dataType:" + dataType);
             case Int64: {
-                List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
+                List<Long> longs = objects.stream().map(p -> (p == null) ? null : (Long) p).collect(Collectors.toList());
                 LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
                 return ScalarField.newBuilder().setLongData(longArray).build();
             }
             case Int32:
             case Int16:
             case Int8: {
-                List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
+                List<Integer> integers = objects.stream().map(p -> (p == null) ? null : (p instanceof Short ? ((Short) p).intValue() : (Integer) p)).collect(Collectors.toList());
                 IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
                 return ScalarField.newBuilder().setIntData(intArray).build();
             }
             case Bool: {
-                List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
+                List<Boolean> booleans = objects.stream().map(p -> (p == null) ? null : (Boolean) p).collect(Collectors.toList());
                 BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
                 return ScalarField.newBuilder().setBoolData(boolArray).build();
             }
             case Float: {
-                List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
+                List<Float> floats = objects.stream().map(p -> (p == null) ? null : (Float) p).collect(Collectors.toList());
                 FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
                 return ScalarField.newBuilder().setFloatData(floatArray).build();
             }
             case Double: {
-                List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
+                List<Double> doubles = objects.stream().map(p -> (p == null) ? null : (Double) p).collect(Collectors.toList());
                 DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
                 return ScalarField.newBuilder().setDoubleData(doubleArray).build();
             }
             case String:
             case VarChar: {
-                List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
+                List<String> strings = objects.stream().map(p -> (p == null) ? null : (String) p).collect(Collectors.toList());
                 StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
                 return ScalarField.newBuilder().setStringData(stringArray).build();
             }
             case JSON: {
-                List<ByteString> byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(p.toString()))
+                List<ByteString> byteStrings = objects.stream().map(p -> (p == null) ? null : ByteString.copyFromUtf8(p.toString()))
                         .collect(Collectors.toList());
                 JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
                 return ScalarField.newBuilder().setJsonData(jsonArray).build();

+ 13 - 7
sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -251,13 +251,6 @@ public class FieldDataWrapper {
                 return packData;
             }
             case Array:
-                List<List<?>> array = new ArrayList<>();
-                ArrayArray arrArray = fieldData.getScalars().getArrayData();
-                for (int i = 0; i < arrArray.getDataCount(); i++) {
-                    ScalarField scalar = arrArray.getData(i);
-                    array.add(getScalarData(arrArray.getElementType(), scalar, null));
-                }
-                return array;
             case Int64:
             case Int32:
             case Int16:
@@ -308,6 +301,19 @@ public class FieldDataWrapper {
             case JSON:
                 List<ByteString> dataList = scalar.getJsonData().getDataList();
                 return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
+            case Array:
+                List<List<?>> array = new ArrayList<>();
+                ArrayArray arrArray = fieldData.getScalars().getArrayData();
+                boolean nullable = validData != null && validData.size() == arrArray.getDataCount();
+                for (int i = 0; i < arrArray.getDataCount(); i++) {
+                    if (nullable && validData.get(i) == Boolean.FALSE) {
+                        array.add(null);
+                    } else {
+                        ScalarField rowData = arrArray.getData(i);
+                        array.add(getScalarData(arrArray.getElementType(), rowData, null));
+                    }
+                }
+                return array;
             default:
                 return new ArrayList<>();
         }

+ 34 - 16
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -70,6 +70,7 @@ import org.testcontainers.milvus.MilvusContainer;
 
 import java.nio.ByteBuffer;
 import java.util.*;
+import java.util.function.Function;
 
 @Testcontainers(disabledWithoutDocker = true)
 class MilvusClientV2DockerTest {
@@ -1892,6 +1893,13 @@ class MilvusClientV2DockerTest {
                 .isNullable(Boolean.TRUE)
                 .maxLength(100)
                 .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("arr")
+                .dataType(DataType.Array)
+                .elementType(DataType.Int32)
+                .isNullable(Boolean.TRUE)
+                .maxCapacity(100)
+                .build());
 
         List<IndexParam> indexParams = new ArrayList<>();
         indexParams.add(IndexParam.builder()
@@ -1920,7 +1928,11 @@ class MilvusClientV2DockerTest {
             } else {
 //                row.add("flag", JsonNull.INSTANCE);
                 row.addProperty("desc", "AAA");
+
+                List<Integer> arr = Arrays.asList(5, 6);
+                row.add("arr", JsonUtils.toJsonTree(arr));
             }
+
             data.add(row);
         }
 
@@ -1930,11 +1942,30 @@ class MilvusClientV2DockerTest {
                 .build());
         Assertions.assertEquals(10, insertResp.getInsertCnt());
 
+        Function<Map<String, Object>, Void> checkFunc =
+                entity -> {
+                    long id = (long)entity.get("id");
+                    if (id%2 == 0) {
+                        Assertions.assertEquals((int)id, entity.get("flag"));
+                        Assertions.assertNull(entity.get("desc"));
+                        Assertions.assertNull(entity.get("arr"));
+                    } else {
+                        Assertions.assertEquals(10, entity.get("flag"));
+                        Assertions.assertEquals("AAA", entity.get("desc"));
+                        Object obj = entity.get("arr");
+                        Assertions.assertInstanceOf(List.class, obj);
+                        List<Integer> arr = (List<Integer>)obj;
+                        Assertions.assertEquals(2, arr.size());
+                        Assertions.assertEquals(5, arr.get(0));
+                        Assertions.assertEquals(6, arr.get(1));
+                    }
+                    return null;
+                };
         // query
         QueryResp queryResp = client.query(QueryReq.builder()
                 .collectionName(randomCollectionName)
                 .filter("id >= 0")
-                .outputFields(Arrays.asList("desc", "flag"))
+                .outputFields(Arrays.asList("desc", "flag", "arr"))
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());
         List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
@@ -1942,14 +1973,7 @@ class MilvusClientV2DockerTest {
         System.out.println("Query results:");
         for (QueryResp.QueryResult result : queryResults) {
             Map<String, Object> entity = result.getEntity();
-            long id = (long)entity.get("id");
-            if (id%2 == 0) {
-                Assertions.assertEquals((int)id, entity.get("flag"));
-                Assertions.assertNull(entity.get("desc"));
-            } else {
-                Assertions.assertEquals(10, entity.get("flag"));
-                Assertions.assertEquals("AAA", entity.get("desc"));
-            }
+            checkFunc.apply(entity);
             System.out.println(result);
         }
 
@@ -1970,13 +1994,7 @@ class MilvusClientV2DockerTest {
         for (SearchResp.SearchResult result : firstResults) {
             long id = (long)result.getId();
             Map<String, Object> entity = result.getEntity();
-            if (id%2 == 0) {
-                Assertions.assertEquals((int)id, entity.get("flag"));
-                Assertions.assertNull(entity.get("desc"));
-            } else {
-                Assertions.assertEquals(10, entity.get("flag"));
-                Assertions.assertEquals("AAA", entity.get("desc"));
-            }
+            checkFunc.apply(entity);
             System.out.println(result);
         }
     }