Browse Source

Support nullable and default value (#1083)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 9 months ago
parent
commit
c83a0e7ada

+ 223 - 12
src/main/java/io/milvus/param/ParamUtils.java

@@ -30,6 +30,8 @@ import io.milvus.param.collection.FieldType;
 import io.milvus.param.dml.*;
 import io.milvus.param.dml.ranker.BaseRanker;
 import io.milvus.response.DescCollResponseWrapper;
+import io.milvus.v2.exception.ErrorCode;
+import io.milvus.v2.exception.MilvusClientException;
 import lombok.Builder;
 import lombok.Getter;
 import lombok.NonNull;
@@ -107,6 +109,37 @@ public class ParamUtils {
         }
     }
 
+    private static boolean checkNullableFieldData(FieldType fieldSchema, Object value, boolean verifyElementType) {
+        if (verifyElementType) {
+            return false; // array element check, go to 1.2 and 2.3
+        }
+
+        // nullable and default value check
+        // 1. if the field is nullable, user can input null for column-based insert
+        //    1) if user input JsonNull, this value is replaced by default value
+        //    2) if user input JsonObject, infer this value by type
+        // 2. if the field is not nullable, user can input null for column-based insert
+        //    1) if user input JsonNull, and default value is null, throw error
+        //    2) if user input JsonNull, and default value is not null, this value is replaced by default value
+        //    3) if user input JsonObject, infer this value by type
+        if (fieldSchema.isNullable()) {
+            if (value == null) {
+                return true; // 1.1
+            }
+        } else {
+            if (value == null) {
+                if (fieldSchema.getDefaultValue() == null) {
+                    String msg = "Field '%s' is not nullable but the input value is null";
+                    throw new ParamException(String.format(msg, fieldSchema.getName())); // 2.1
+                } else {
+                    return true; // 2.2
+                }
+            }
+        }
+
+        return false; // go to 1.2 and 2.3
+    }
+
     public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean verifyElementType) {
         HashMap<DataType, String> errMsgs = getTypeErrorMsgForColumnInsert();
         DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
@@ -181,6 +214,9 @@ public class ParamUtils {
                 break;
             case Int64:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof Long)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -190,6 +226,9 @@ public class ParamUtils {
             case Int16:
             case Int8:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof Short) && !(value instanceof Integer)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -197,6 +236,9 @@ public class ParamUtils {
                 break;
             case Bool:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof Boolean)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -204,6 +246,9 @@ public class ParamUtils {
                 break;
             case Float:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof Float)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -211,6 +256,9 @@ public class ParamUtils {
                 break;
             case Double:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof Double)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -219,6 +267,9 @@ public class ParamUtils {
             case VarChar:
             case String:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof String)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -229,6 +280,9 @@ public class ParamUtils {
                 break;
             case JSON:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof JsonElement)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -236,6 +290,9 @@ public class ParamUtils {
                 break;
             case Array:
                 for (Object value : values) {
+                    if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
+                        continue;
+                    }
                     if (!(value instanceof List)) {
                         throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
                     }
@@ -253,8 +310,32 @@ public class ParamUtils {
     }
 
     public static Object checkFieldValue(FieldType fieldSchema, JsonElement value) {
-        HashMap<DataType, String> errMsgs = getTypeErrorMsgForRowInsert();
         DataType dataType = fieldSchema.getDataType();
+        // nullable and default value check
+        // 1. if the field is nullable, user can input JsonNull/JsonObject(for row-based insert)
+        //    1) if user input JsonNull, this value is replaced by default value
+        //    2) if user input JsonObject, infer this value by type
+        // 2. if the field is not nullable, user can input JsonNull/JsonObject(for row-based insert)
+        //    1) if user input JsonNull, and default value is null, throw error
+        //    2) if user input JsonNull, and default value is not null, this value is replaced by default value
+        //    3) if user input JsonObject, infer this value by type
+        if (fieldSchema.isNullable()) {
+            if (value instanceof JsonNull) {
+                return fieldSchema.getDefaultValue(); // 1.1
+            }
+        } else {
+            if (value instanceof JsonNull) {
+                if (fieldSchema.getDefaultValue() == null) {
+                    String msg = "Field '%s' is not nullable but the input value is null";
+                    throw new ParamException(String.format(msg, fieldSchema.getName())); // 2.1
+                } else {
+                    return fieldSchema.getDefaultValue(); // 2.2
+                }
+            }
+        }
+
+        // 1.2 and 1.3, infer value by type
+        HashMap<DataType, String> errMsgs = getTypeErrorMsgForRowInsert();
 
         switch (dataType) {
             case FloatVector: {
@@ -569,21 +650,27 @@ public class ParamUtils {
 
                     // check normalField
                     JsonElement rowFieldData = row.get(fieldName);
-                    if (rowFieldData != null && !rowFieldData.isJsonNull()) {
+                    if (rowFieldData == null) {
+                        // check if autoId
                         if (fieldType.isAutoID()) {
-                            String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
-                            throw new ParamException(msg);
+                            continue;
                         }
-                        Object fieldValue = checkFieldValue(fieldType, rowFieldData);
-                        insertDataInfo.getData().add(fieldValue);
-                        nameInsertInfo.put(fieldName, insertDataInfo);
-                    } else {
-                        // check if autoId
-                        if (!fieldType.isAutoID()) {
+                        // if the field doesn't have default value, require user provide the value
+                        if (!fieldType.isNullable() && fieldType.getDefaultValue() == null) {
                             String msg = String.format("The field: %s is not provided.", fieldType.getName());
                             throw new ParamException(msg);
                         }
+
+                        rowFieldData = JsonNull.INSTANCE;
+                    }
+
+                    if (fieldType.isAutoID()) {
+                        String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
+                        throw new ParamException(msg);
                     }
+                    Object fieldValue = checkFieldValue(fieldType, rowFieldData);
+                    insertDataInfo.getData().add(fieldValue);
+                    nameInsertInfo.put(fieldName, insertDataInfo);
                 }
 
                 // deal with dynamicField
@@ -976,10 +1063,22 @@ public class ParamUtils {
         DataType dataType = fieldType.getDataType();
         String fieldName = fieldType.getName();
         FieldData.Builder builder = FieldData.newBuilder();
+
         if (isVectorDataType(dataType)) {
             VectorField vectorField = genVectorField(dataType, objects);
             return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
         } else {
+            if (fieldType.isNullable() || fieldType.getDefaultValue() != null) {
+                List<Object> tempObjects = new ArrayList<>();
+                for (Object obj : objects) {
+                    builder.addValidData(obj != null);
+                    if (obj != null) {
+                        tempObjects.add(obj);
+                    }
+                }
+                objects = tempObjects;
+            }
+
             ScalarField scalarField = genScalarField(fieldType, objects);
             if (isDynamic) {
                 return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
@@ -1191,7 +1290,13 @@ public class ParamUtils {
                 .withAutoID(field.getAutoID())
                 .withDataType(field.getDataType())
                 .withElementType(field.getElementType())
-                .withIsDynamic(field.getIsDynamic());
+                .withIsDynamic(field.getIsDynamic())
+                .withNullable(field.getNullable());
+
+        Object defaultValue = valueFieldToObject(field.getDefaultValue(), field.getDataType());
+        if (field.getNullable() || defaultValue != null) {
+            builder.withDefaultValue(defaultValue);
+        }
 
         if (field.getIsDynamic()) {
             builder.withIsDynamic(true);
@@ -1218,7 +1323,20 @@ public class ParamUtils {
                 .setAutoID(field.isAutoID())
                 .setDataType(field.getDataType())
                 .setElementType(field.getElementType())
-                .setIsDynamic(field.isDynamic());
+                .setIsDynamic(field.isDynamic())
+                .setNullable(field.isNullable());
+        DataType dType = field.getDataType();
+        if (!ParamUtils.isVectorDataType(dType) && !field.isPrimaryKey()) {
+            ValueField value = ParamUtils.objectToValueField(field.getDefaultValue(), dType);
+            if (value != null) {
+                builder.setDefaultValue(value);
+            } else if (field.getDefaultValue() != null) {
+                String msg = String.format("Illegal default value for %s type field. Please use Short for Int8/Int16 fields, " +
+                        "Short/Integer for Int32 fields, Short/Integer/Long for Int64 fields, Boolean for Bool fields, " +
+                        "String for Varchar fields, JsonObject for JSON fields.", dType.name());
+                throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg);
+            }
+        }
 
         // assemble typeParams for CollectionSchema
         List<KeyValuePair> typeParamsList = AssembleKvPair(field.getTypeParams());
@@ -1229,6 +1347,99 @@ public class ParamUtils {
         return builder.build();
     }
 
+    public static ValueField objectToValueField(Object obj, DataType dataType) {
+        if (obj == null) {
+            return null;
+        }
+
+        ValueField.Builder builder = ValueField.newBuilder();
+        switch (dataType) {
+            case Int8:
+            case Int16:
+                if (obj instanceof Short) {
+                    return builder.setIntData(((Short)obj).intValue()).build();
+                }
+                break;
+            case Int32:
+                if (obj instanceof Short) {
+                    return builder.setIntData(((Short)obj).intValue()).build();
+                } else if (obj instanceof Integer) {
+                    return builder.setIntData((Integer) obj).build();
+                }
+                break;
+            case Int64:
+                if (obj instanceof Short) {
+                    return builder.setLongData(((Short)obj).longValue()).build();
+                } else if (obj instanceof Integer) {
+                    return builder.setLongData(((Integer)obj).longValue()).build();
+                } else if (obj instanceof Long) {
+                    return builder.setLongData((Long) obj).build();
+                }
+                break;
+            case Float:
+                if (obj instanceof Float) {
+                    return builder.setFloatData((Float) obj).build();
+                }
+                break;
+            case Double:
+                if (obj instanceof Float) {
+                    return builder.setDoubleData(((Float)obj).doubleValue()).build();
+                } else if (obj instanceof Double) {
+                    return builder.setDoubleData((Double) obj).build();
+                }
+                break;
+            case Bool:
+                if (obj instanceof Boolean) {
+                    return builder.setBoolData((Boolean) obj).build();
+                }
+                break;
+            case VarChar:
+            case String:
+                if (obj instanceof String) {
+                    return builder.setStringData((String) obj).build();
+                }
+                break;
+            case JSON:
+                if (obj instanceof JsonObject) {
+                    return builder.setStringData(obj.toString()).build();
+                }
+                break;
+            default:
+                break;
+        }
+        return null;
+    }
+
+    public static Object valueFieldToObject(ValueField value, DataType dataType) {
+        if (value == null || value.getDataCase() == ValueField.DataCase.DATA_NOT_SET) {
+            return null;
+        }
+
+        switch (dataType) {
+            case Int8:
+            case Int16:
+                return (short) value.getIntData();
+            case Int32:
+                return value.getIntData();
+            case Int64:
+                return value.getLongData();
+            case Float:
+                return value.getFloatData();
+            case Double:
+                return value.getDoubleData();
+            case Bool:
+                return value.getBoolData();
+            case VarChar:
+            case String:
+                return value.getStringData();
+            case JSON:
+                return new Gson().fromJson(value.getStringData(), JsonObject.class);
+            default:
+                break;
+        }
+        return null;
+    }
+
     public static List<KeyValuePair> AssembleKvPair(Map<String, String> sourceMap) {
         List<KeyValuePair> result = new ArrayList<>();
 

+ 58 - 0
src/main/java/io/milvus/param/collection/FieldType.java

@@ -47,6 +47,8 @@ public class FieldType {
     private final boolean partitionKey;
     private final boolean isDynamic;
     private final DataType elementType;
+    private final boolean nullable;
+    private final Object defaultValue;
 
     private FieldType(@NonNull Builder builder){
         this.name = builder.name;
@@ -58,6 +60,8 @@ public class FieldType {
         this.partitionKey = builder.partitionKey;
         this.isDynamic = builder.isDynamic;
         this.elementType = builder.elementType;
+        this.nullable = builder.nullable;
+        this.defaultValue = builder.defaultValue;
     }
 
     public int getDimension() {
@@ -101,6 +105,9 @@ public class FieldType {
         private boolean partitionKey = false;
         private boolean isDynamic = false;
         private DataType elementType = DataType.None; // only for Array type field
+        private boolean nullable = false; // only for scalar fields(not include Array fields)
+        private Object defaultValue = null; // only for scalar fields
+        private boolean enableDefaultValue = false; // a flag to pass the default value to server or not
 
         private Builder() {
         }
@@ -255,6 +262,51 @@ public class FieldType {
             return this;
         }
 
+        /**
+         * Sets this field is nullable or not.
+         * Primary key field, vector fields, Array fields cannot be nullable.
+         *
+         *  1. if the field is nullable, user can input JsonNull/JsonObject(for row-based insert), or input null/object(for column-based insert)
+         *     1) if user input JsonNull, this value is replaced by default value
+         *     2) if user input JsonObject, infer this value by type
+         *  2. if the field is not nullable, user can input JsonNull/JsonObject(for row-based insert), or input null/object(for column-based insert)
+         *     1) if user input JsonNull, and default value is null, throw error
+         *     2) if user input JsonNull, and default value is not null, this value is replaced by default value
+         *     3) if user input JsonObject, infer this value by type
+         *
+         * @param nullable true is nullable, false is not
+         * @return <code>Builder</code>
+         */
+        public Builder withNullable(boolean nullable) {
+            this.nullable = nullable;
+            return this;
+        }
+
+        /**
+         * Sets default value of this field.
+         * If nullable is false, the default value cannot be null. If nullable is true, the default value can be null.
+         * Only scalar fields(not include Array field) support default value.
+         * The default value type must obey the following rule:
+         * - Boolean for Bool fields
+         * - Short for Int8/Int16 fields
+         * - Short/Integer for Int32 fields
+         * - Short/Integer/Long for Int64 fields
+         * - Float for Float fields
+         * - Double for Double fields
+         * - String for Varchar fields
+         * - JsonObject for JSON fields
+         *
+         * For JSON field, you can use JsonNull.INSTANCE as default value. For other scalar fields, you can use null as default value.
+         *
+         * @param obj the default value
+         * @return <code>Builder</code>
+         */
+        public Builder withDefaultValue(Object obj) {
+            this.defaultValue = obj;
+            this.enableDefaultValue = true;
+            return this;
+        }
+
         /**
          * Verifies parameters and creates a new {@link FieldType} instance.
          *
@@ -331,6 +383,12 @@ public class FieldType {
                 }
             }
 
+            // check the input here to pop error messages earlier
+            if (enableDefaultValue && defaultValue == null && !nullable) {
+                String msg = String.format("Default value cannot be null for field '%s' that is defined as nullable == false.", name);
+                throw new ParamException(msg);
+            }
+
             return new FieldType(this);
         }
     }

+ 22 - 10
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -33,7 +33,6 @@ import java.nio.ByteOrder;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.SortedMap;
-import java.util.TreeMap;
 import java.util.stream.Collectors;
 
 import com.google.protobuf.ByteString;
@@ -256,7 +255,7 @@ public class FieldDataWrapper {
                 ArrayArray arrArray = fieldData.getScalars().getArrayData();
                 for (int i = 0; i < arrArray.getDataCount(); i++) {
                     ScalarField scalar = arrArray.getData(i);
-                    array.add(getScalarData(arrArray.getElementType(), scalar));
+                    array.add(getScalarData(arrArray.getElementType(), scalar, null));
                 }
                 return array;
             case Int64:
@@ -269,30 +268,43 @@ public class FieldDataWrapper {
             case VarChar:
             case String:
             case JSON:
-                return getScalarData(dt, fieldData.getScalars());
+                return getScalarData(dt, fieldData.getScalars(), fieldData.getValidDataList());
             default:
                 throw new IllegalResponseException("Unsupported data type returned by FieldData");
         }
     }
 
-    private List<?> getScalarData(DataType dt, ScalarField scalar) {
+    private List<?> setNoneData(List<?> data, List<Boolean> validData) {
+        if (validData != null && validData.size() == data.size()) {
+            List<?> newData = new ArrayList<>(data); // copy the list since the data is come from grpc is not mutable
+            for (int i = 0; i < validData.size(); i++) {
+                if (validData.get(i) == Boolean.FALSE) {
+                    newData.set(i, null);
+                }
+            }
+            return newData;
+        }
+        return data;
+    }
+
+    private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> validData) {
         switch (dt) {
             case Int64:
-                return scalar.getLongData().getDataList();
+                return setNoneData(scalar.getLongData().getDataList(), validData);
             case Int32:
             case Int16:
             case Int8:
-                return scalar.getIntData().getDataList();
+                return setNoneData(scalar.getIntData().getDataList(), validData);
             case Bool:
-                return scalar.getBoolData().getDataList();
+                return setNoneData(scalar.getBoolData().getDataList(), validData);
             case Float:
-                return scalar.getFloatData().getDataList();
+                return setNoneData(scalar.getFloatData().getDataList(), validData);
             case Double:
-                return scalar.getDoubleData().getDataList();
+                return setNoneData(scalar.getDoubleData().getDataList(), validData);
             case VarChar:
             case String:
                 ProtocolStringList protoStrList = scalar.getStringData().getDataList();
-                return protoStrList.subList(0, protoStrList.size());
+                return setNoneData(protoStrList.subList(0, protoStrList.size()), validData);
             case JSON:
                 List<ByteString> dataList = scalar.getJsonData().getDataList();
                 return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());

+ 1 - 2
src/main/java/io/milvus/response/QueryResultsWrapper.java

@@ -155,7 +155,6 @@ public class QueryResultsWrapper extends RowRecordWrapper {
                         return innerObj;
                     }
                 }
-                throw new ParamException("The key name is not found");
             }
 
             return obj;
@@ -180,7 +179,7 @@ public class QueryResultsWrapper extends RowRecordWrapper {
         public String toString() {
             List<String> pairs = new ArrayList<>();
             fieldValues.forEach((keyName, fieldValue) -> {
-                pairs.add(keyName + ":" + fieldValue.toString());
+                pairs.add(keyName + ":" + fieldValue);
             });
             return pairs.toString();
         }

+ 1 - 1
src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -335,7 +335,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
         public String toString() {
             List<String> pairs = new ArrayList<>();
             fieldValues.forEach((keyName, fieldValue) -> {
-                pairs.add(keyName + ":" + fieldValue.toString());
+                pairs.add(keyName + ":" + fieldValue);
             });
 
             if (strID.isEmpty()) {

+ 12 - 0
src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java

@@ -42,4 +42,16 @@ public class AddFieldReq {
     private Integer dimension;
     private io.milvus.v2.common.DataType elementType;
     private Integer maxCapacity;
+    @Builder.Default
+    private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)
+    @Builder.Default
+    private Object defaultValue = null; // only for scalar fields
+    @Builder.ObtainVia(field = "hiddenField")
+    private boolean enableDefaultValue = false; // a flag to pass the default value to server or not
+
+    AddFieldReq setDefaultValue(Object obj) {
+        enableDefaultValue = true; // automatically set this flag
+        this.defaultValue = obj;
+        return this;
+    }
 }

+ 13 - 0
src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

@@ -92,6 +92,13 @@ public class CreateCollectionReq {
         private List<CreateCollectionReq.FieldSchema> fieldSchemaList = new ArrayList<>();
 
         public CollectionSchema addField(AddFieldReq addFieldReq) {
+            // check the input here to pop error messages earlier
+            if (addFieldReq.isEnableDefaultValue() && addFieldReq.getDefaultValue() == null
+                    && addFieldReq.getIsNullable() == Boolean.FALSE) {
+                String msg = String.format("Default value cannot be null for field '%s' that is defined as nullable == false.", addFieldReq.getFieldName());
+                throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg);
+            }
+
             CreateCollectionReq.FieldSchema fieldSchema = FieldSchema.builder()
                     .name(addFieldReq.getFieldName())
                     .dataType(addFieldReq.getDataType())
@@ -99,6 +106,8 @@ public class CreateCollectionReq {
                     .isPrimaryKey(addFieldReq.getIsPrimaryKey())
                     .isPartitionKey(addFieldReq.getIsPartitionKey())
                     .autoID(addFieldReq.getAutoID())
+                    .isNullable(addFieldReq.getIsNullable())
+                    .defaultValue(addFieldReq.getDefaultValue())
                     .build();
             if (addFieldReq.getDataType().equals(DataType.Array)) {
                 if (addFieldReq.getElementType() == null) {
@@ -147,5 +156,9 @@ public class CreateCollectionReq {
         private Boolean autoID = Boolean.FALSE;
         private DataType elementType;
         private Integer maxCapacity;
+        @Builder.Default
+        private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)
+        @Builder.Default
+        private Object defaultValue = null; // only for scalar fields
     }
 }

+ 4 - 2
src/main/java/io/milvus/v2/utils/ConvertUtils.java

@@ -39,9 +39,8 @@ import java.util.stream.Collectors;
 
 public class ConvertUtils {
     public List<QueryResp.QueryResult> getEntities(QueryResults response) {
-        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(response);
         List<QueryResp.QueryResult> entities = new ArrayList<>();
-
+        // count(*) ?
         if(response.getFieldsDataList().stream().anyMatch(fieldData -> fieldData.getFieldName().equals("count(*)"))){
             Map<String, Object> countField = new HashMap<>();
             long numOfEntities = response.getFieldsDataList().stream().filter(fieldData -> fieldData.getFieldName().equals("count(*)")).map(FieldData::getScalars).collect(Collectors.toList()).get(0).getLongData().getData(0);
@@ -54,6 +53,9 @@ public class ConvertUtils {
 
             return entities;
         }
+
+        // normal query
+        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(response);
         queryResultsWrapper.getRowRecords().forEach(rowRecord -> {
             QueryResp.QueryResult queryResult = QueryResp.QueryResult.builder()
                     .entity(rowRecord.getFieldValues())

+ 16 - 9
src/main/java/io/milvus/v2/utils/DataUtils.java

@@ -20,6 +20,7 @@
 package io.milvus.v2.utils;
 
 import com.google.gson.JsonElement;
+import com.google.gson.JsonNull;
 import com.google.gson.JsonObject;
 import io.milvus.exception.ParamException;
 import io.milvus.grpc.*;
@@ -159,21 +160,27 @@ public class DataUtils {
 
                     // check normalField
                     JsonElement rowFieldData = row.get(fieldName);
-                    if (rowFieldData != null) {
+                    if (rowFieldData == null) {
+                        // check if autoId
                         if (fieldType.isAutoID()) {
-                            String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
-                            throw new ParamException(msg);
+                            continue;
                         }
-                        Object fieldValue = ParamUtils.checkFieldValue(fieldType, rowFieldData);
-                        insertDataInfo.getData().add(fieldValue);
-                        nameInsertInfo.put(fieldName, insertDataInfo);
-                    } else {
-                        // check if autoId
-                        if (!fieldType.isAutoID()) {
+                        // if the field doesn't have default value, require user provide the value
+                        if (!fieldType.isNullable() && fieldType.getDefaultValue() == null) {
                             String msg = String.format("The field: %s is not provided.", fieldType.getName());
                             throw new ParamException(msg);
                         }
+
+                        rowFieldData = JsonNull.INSTANCE;
+                    }
+
+                    if (fieldType.isAutoID()) {
+                        String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
+                        throw new ParamException(msg);
                     }
+                    Object fieldValue = ParamUtils.checkFieldValue(fieldType, rowFieldData);
+                    insertDataInfo.getData().add(fieldValue);
+                    nameInsertInfo.put(fieldName, insertDataInfo);
                 }
 
                 // deal with dynamicField

+ 23 - 7
src/main/java/io/milvus/v2/utils/SchemaUtils.java

@@ -19,10 +19,10 @@
 
 package io.milvus.v2.utils;
 
-import io.milvus.grpc.CollectionSchema;
-import io.milvus.grpc.DataType;
-import io.milvus.grpc.FieldSchema;
-import io.milvus.grpc.KeyValuePair;
+import io.milvus.grpc.*;
+import io.milvus.param.ParamUtils;
+import io.milvus.v2.exception.ErrorCode;
+import io.milvus.v2.exception.MilvusClientException;
 import io.milvus.v2.service.collection.request.CreateCollectionReq;
 
 import java.util.ArrayList;
@@ -30,14 +30,28 @@ import java.util.List;
 
 public class SchemaUtils {
     public static FieldSchema convertToGrpcFieldSchema(CreateCollectionReq.FieldSchema fieldSchema) {
-        FieldSchema schema = FieldSchema.newBuilder()
+        DataType dType = DataType.valueOf(fieldSchema.getDataType().name());
+        FieldSchema.Builder builder = FieldSchema.newBuilder()
                 .setName(fieldSchema.getName())
                 .setDescription(fieldSchema.getDescription())
-                .setDataType(DataType.valueOf(fieldSchema.getDataType().name()))
+                .setDataType(dType)
                 .setIsPrimaryKey(fieldSchema.getIsPrimaryKey())
                 .setIsPartitionKey(fieldSchema.getIsPartitionKey())
                 .setAutoID(fieldSchema.getAutoID())
-                .build();
+                .setNullable(fieldSchema.getIsNullable());
+        if (!ParamUtils.isVectorDataType(dType) && !fieldSchema.getIsPrimaryKey()) {
+            ValueField value = ParamUtils.objectToValueField(fieldSchema.getDefaultValue(), dType);
+            if (value != null) {
+                builder.setDefaultValue(value);
+            } else if (fieldSchema.getDefaultValue() != null) {
+                String msg = String.format("Illegal default value for %s type field. Please use Short for Int8/Int16 fields, " +
+                        "Short/Integer for Int32 fields, Short/Integer/Long for Int64 fields, Boolean for Bool fields, " +
+                        "String for Varchar fields, JsonObject for JSON fields.", dType.name());
+                throw new MilvusClientException(ErrorCode.INVALID_PARAMS, msg);
+            }
+        }
+
+        FieldSchema schema = builder.build();
         if(fieldSchema.getDimension() != null){
             schema = schema.toBuilder().addTypeParams(KeyValuePair.newBuilder().setKey("dim").setValue(String.valueOf(fieldSchema.getDimension())).build()).build();
         }
@@ -77,6 +91,8 @@ public class SchemaUtils {
                 .isPartitionKey(fieldSchema.getIsPartitionKey())
                 .autoID(fieldSchema.getAutoID())
                 .elementType(io.milvus.v2.common.DataType.valueOf(fieldSchema.getElementType().name()))
+                .isNullable(fieldSchema.getNullable())
+                .defaultValue(ParamUtils.valueFieldToObject(fieldSchema.getDefaultValue(), fieldSchema.getDataType()))
                 .build();
         for (KeyValuePair keyValuePair : fieldSchema.getTypeParamsList()) {
             if(keyValuePair.getKey().equals("dim")){

+ 169 - 2
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -84,7 +84,7 @@ class MilvusClientDockerTest {
     private static final Random RANDOM = new Random();
 
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.11");
+    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20240927-1f271e39-amd64");
 
     @BeforeAll
     public static void setUp() {
@@ -2461,7 +2461,8 @@ class MilvusClientDockerTest {
             Object name = record.get(DataType.VarChar.name());
             Assertions.assertNotNull(name);
             Assertions.assertEquals("name_18", name);
-            Assertions.assertThrows(ParamException.class, () -> record.get("dynamic_value")); // we didn't set dynamic_value for No.18 row
+            Assertions.assertFalse(record.contains("dynamic_value"));
+            Assertions.assertNull(record.get("dynamic_value")); // we didn't set dynamic_value for No.18 row
         }
 
         // upsert to change the no.5 and no.18 items
@@ -3136,4 +3137,170 @@ class MilvusClientDockerTest {
             Assertions.fail(e.getMessage());
         }
     }
+
+    @Test
+    void testNullableAndDefaultValue() {
+        String randomCollectionName = generator.generate(10);
+
+        CollectionSchemaParam.Builder builder = CollectionSchemaParam.newBuilder();
+        builder.addFieldType(FieldType.newBuilder()
+                .withPrimaryKey(true)
+                .withAutoID(false)
+                .withDataType(DataType.Int64)
+                .withName("id")
+                .build());
+        builder.addFieldType(FieldType.newBuilder()
+                .withDataType(DataType.FloatVector)
+                .withName("vector")
+                .withDimension(DIMENSION)
+                .build());
+        builder.addFieldType(FieldType.newBuilder()
+                .withDataType(DataType.Int32)
+                .withName("flag")
+                .withMaxLength(100)
+                .withDefaultValue(10)
+                .build());
+        builder.addFieldType(FieldType.newBuilder()
+                .withDataType(DataType.VarChar)
+                .withName("desc")
+                .withMaxLength(100)
+                .withNullable(true)
+                .build());
+        R<RpcStatus> createR = client.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withSchema(builder.build())
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue());
+
+        // create index on scalar field
+        CreateIndexParam indexParam = CreateIndexParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFieldName("vector")
+                .withIndexType(IndexType.FLAT)
+                .withMetricType(MetricType.L2)
+                .build();
+
+        R<RpcStatus> createIndexR = client.createIndex(indexParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue());
+
+        client.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .build());
+
+        // insert by row-based
+        List<JsonObject> data = new ArrayList<>();
+        Gson gson = new Gson();
+        for (int i = 0; i < 10; i++) {
+            JsonObject row = new JsonObject();
+            List<Float> vector = generateFloatVector();
+            row.addProperty("id", i);
+            row.add("vector", gson.toJsonTree(vector));
+            if (i%2 == 0) {
+                row.addProperty("flag", i);
+                row.add("desc", JsonNull.INSTANCE);
+            } else {
+                row.addProperty("desc", "AAA");
+            }
+            data.add(row);
+        }
+
+        R<MutationResult> insertR = client.insert(InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withRows(data)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
+
+        // insert by column-based
+        List<List<Float>> vectors = generateFloatVectors(10);
+        List<Long> ids = new ArrayList<>();
+        List<Integer> flags = new ArrayList<>();
+        List<String> descs = new ArrayList<>();
+        for (int i = 10; i < 20; i++) {
+            ids.add((long)i);
+            if (i%2 == 0) {
+                flags.add(i);
+                descs.add(null);
+            } else {
+                flags.add(null);
+                descs.add("AAA");
+            }
+
+        }
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field("id", ids));
+        fieldsInsert.add(new InsertParam.Field("vector", vectors));
+        fieldsInsert.add(new InsertParam.Field("flag", flags));
+        fieldsInsert.add(new InsertParam.Field("desc", descs));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFields(fieldsInsert)
+                .build();
+
+        insertR = client.insert(insertParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue());
+
+        // query
+        QueryParam queryParam = QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr("id >= 0")
+                .addOutField("flag")
+                .addOutField("desc")
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .build();
+
+        R<QueryResults> queryR = client.query(queryParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), queryR.getStatus().intValue());
+
+        // verify query result
+        QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(queryR.getData());
+        List<QueryResultsWrapper.RowRecord> records = queryResultsWrapper.getRowRecords();
+        System.out.println("Query results:");
+        for (QueryResultsWrapper.RowRecord record:records) {
+            long id = (long)record.get("id");
+            if (id%2 == 0) {
+                Assertions.assertEquals((int)id, record.get("flag"));
+                Assertions.assertNull(record.get("desc"));
+            } else {
+                Assertions.assertEquals(10, record.get("flag"));
+                Assertions.assertEquals("AAA", record.get("desc"));
+            }
+            System.out.println(record);
+        }
+
+        // search the row-based items
+        List<List<Float>> searchVectors = generateFloatVectors(1);
+        SearchParam searchParam = SearchParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withMetricType(MetricType.L2)
+                .withTopK(10)
+                .withFloatVectors(searchVectors)
+                .withVectorFieldName("vector")
+                .withParams("{}")
+                .addOutField("flag")
+                .addOutField("desc")
+                .withConsistencyLevel(ConsistencyLevelEnum.BOUNDED)
+                .build();
+
+        R<SearchResults> searchR = client.search(searchParam);
+        Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue());
+
+        // verify the search result
+        SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
+        List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
+        System.out.println("Search results:");
+        Assertions.assertEquals(10, scores.size());
+        for (SearchResultsWrapper.IDScore score : scores) {
+            long id = score.getLongID();
+            Map<String, Object> fieldValues = score.getFieldValues();
+            if (id%2 == 0) {
+                Assertions.assertEquals((int)id, fieldValues.get("flag"));
+                Assertions.assertNull(fieldValues.get("desc"));
+            } else {
+                Assertions.assertEquals(10, fieldValues.get("flag"));
+                Assertions.assertEquals("AAA", fieldValues.get("desc"));
+            }
+            System.out.println(score);
+        }
+    }
 }

+ 120 - 1
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -80,7 +80,7 @@ class MilvusClientV2DockerTest {
     private static final Random RANDOM = new Random();
 
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.11");
+    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20240927-1f271e39-amd64");
 
     @BeforeAll
     public static void setUp() {
@@ -1799,4 +1799,123 @@ class MilvusClientV2DockerTest {
             Assertions.fail(e.getMessage());
         }
     }
+
+    @Test
+    void testNullableAndDefaultValue() {
+        String randomCollectionName = generator.generate(10);
+        int dim = 4;
+
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .autoID(Boolean.FALSE)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("vector")
+                .dataType(DataType.FloatVector)
+                .dimension(dim)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("flag")
+                .dataType(DataType.Int32)
+                .defaultValue((int)10)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName("desc")
+                .dataType(DataType.VarChar)
+                .isNullable(Boolean.TRUE)
+                .maxLength(100)
+                .build());
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName("vector")
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.L2)
+                .build());
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexParams)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("Collection created");
+
+        // insert by row-based
+        List<JsonObject> data = new ArrayList<>();
+        Gson gson = new Gson();
+        for (int i = 0; i < 10; i++) {
+            JsonObject row = new JsonObject();
+            List<Float> vector = generateFolatVector(dim);
+            row.addProperty("id", i);
+            row.add("vector", gson.toJsonTree(vector));
+            if (i%2 == 0) {
+                row.addProperty("flag", i);
+                row.add("desc", JsonNull.INSTANCE);
+            } else {
+//                row.add("flag", JsonNull.INSTANCE);
+                row.addProperty("desc", "AAA");
+            }
+            data.add(row);
+        }
+
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(data)
+                .build());
+        Assertions.assertEquals(10, insertResp.getInsertCnt());
+
+        // query
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(randomCollectionName)
+                .filter("id >= 0")
+                .outputFields(Lists.newArrayList("*"))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+        Assertions.assertEquals(10, queryResults.size());
+        System.out.println("Query results:");
+        for (QueryResp.QueryResult result : queryResults) {
+            Map<String, Object> entity = result.getEntity();
+            long id = (long)entity.get("id");
+            if (id%2 == 0) {
+                Assertions.assertEquals((int)id, entity.get("flag"));
+                Assertions.assertNull(entity.get("desc"));
+            } else {
+                Assertions.assertEquals(10, entity.get("flag"));
+                Assertions.assertEquals("AAA", entity.get("desc"));
+            }
+            System.out.println(result);
+        }
+
+        // search
+        SearchResp searchResp = client.search(SearchReq.builder()
+                .collectionName(randomCollectionName)
+                .annsField("vector")
+                .data(Collections.singletonList(new FloatVec(generateFolatVector(dim))))
+                .topK(10)
+                .outputFields(Lists.newArrayList("*"))
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        Assertions.assertEquals(1, searchResults.size());
+        List<SearchResp.SearchResult> firstResults = searchResults.get(0);
+        Assertions.assertEquals(10, firstResults.size());
+        System.out.println("Search results:");
+        for (SearchResp.SearchResult result : firstResults) {
+            long id = (long)result.getId();
+            Map<String, Object> entity = result.getEntity();
+            if (id%2 == 0) {
+                Assertions.assertEquals((int)id, entity.get("flag"));
+                Assertions.assertNull(entity.get("desc"));
+            } else {
+                Assertions.assertEquals(10, entity.get("flag"));
+                Assertions.assertEquals("AAA", entity.get("desc"));
+            }
+            System.out.println(result);
+        }
+    }
 }