|
@@ -72,7 +72,7 @@ public class ParamUtils {
|
|
return typeErrMsg;
|
|
return typeErrMsg;
|
|
}
|
|
}
|
|
|
|
|
|
- private static HashMap<DataType, String> getTypeErrorMsgForRowInsert() {
|
|
|
|
|
|
+ public static HashMap<DataType, String> getTypeErrorMsgForRowInsert() {
|
|
final HashMap<DataType, String> typeErrMsg = new HashMap<>();
|
|
final HashMap<DataType, String> typeErrMsg = new HashMap<>();
|
|
typeErrMsg.put(DataType.None, "Type mismatch for field '%s': the field type is illegal.");
|
|
typeErrMsg.put(DataType.None, "Type mismatch for field '%s': the field type is illegal.");
|
|
typeErrMsg.put(DataType.Bool, "Type mismatch for field '%s': Bool field value type must be JsonPrimitive.");
|
|
typeErrMsg.put(DataType.Bool, "Type mismatch for field '%s': Bool field value type must be JsonPrimitive.");
|
|
@@ -99,7 +99,7 @@ public class ParamUtils {
|
|
checkFieldData(fieldSchema, values, false);
|
|
checkFieldData(fieldSchema, values, false);
|
|
}
|
|
}
|
|
|
|
|
|
- private static int calculateBinVectorDim(DataType dataType, int byteCount) {
|
|
|
|
|
|
+ public static int calculateBinVectorDim(DataType dataType, int byteCount) {
|
|
if (dataType == DataType.BinaryVector) {
|
|
if (dataType == DataType.BinaryVector) {
|
|
return byteCount*8; // for BinaryVector, each byte is 8 dimensions
|
|
return byteCount*8; // for BinaryVector, each byte is 8 dimensions
|
|
} else if (dataType == DataType.Int8Vector) {
|
|
} else if (dataType == DataType.Int8Vector) {
|
|
@@ -313,8 +313,8 @@ public class ParamUtils {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- public static Object checkFieldValue(FieldType fieldSchema, JsonElement value) {
|
|
|
|
- DataType dataType = fieldSchema.getDataType();
|
|
|
|
|
|
+ public static Object checkFieldValue(String fieldName, DataType dataType, DataType elementType, int dim, int maxLength,
|
|
|
|
+ int maxCapacity, boolean isNullable, Object defaultVal, JsonElement value) {
|
|
// nullable and default value check
|
|
// nullable and default value check
|
|
// 1. if the field is nullable, user can input JsonNull/JsonObject(for row-based insert)
|
|
// 1. if the field is nullable, user can input JsonNull/JsonObject(for row-based insert)
|
|
// 1) if user input JsonNull, this value is replaced by default value
|
|
// 1) if user input JsonNull, this value is replaced by default value
|
|
@@ -323,17 +323,17 @@ public class ParamUtils {
|
|
// 1) if user input JsonNull, and default value is null, throw error
|
|
// 1) if user input JsonNull, and default value is null, throw error
|
|
// 2) if user input JsonNull, and default value is not null, this value is replaced by default value
|
|
// 2) if user input JsonNull, and default value is not null, this value is replaced by default value
|
|
// 3) if user input JsonObject, infer this value by type
|
|
// 3) if user input JsonObject, infer this value by type
|
|
- if (fieldSchema.isNullable()) {
|
|
|
|
|
|
+ if (isNullable) {
|
|
if (value instanceof JsonNull) {
|
|
if (value instanceof JsonNull) {
|
|
- return fieldSchema.getDefaultValue(); // 1.1
|
|
|
|
|
|
+ return defaultVal; // 1.1
|
|
}
|
|
}
|
|
} else {
|
|
} else {
|
|
if (value instanceof JsonNull) {
|
|
if (value instanceof JsonNull) {
|
|
- if (fieldSchema.getDefaultValue() == null) {
|
|
|
|
|
|
+ if (defaultVal == null) {
|
|
String msg = "Field '%s' is not nullable but the input value is null";
|
|
String msg = "Field '%s' is not nullable but the input value is null";
|
|
- throw new ParamException(String.format(msg, fieldSchema.getName())); // 2.1
|
|
|
|
|
|
+ throw new ParamException(String.format(msg, fieldName)); // 2.1
|
|
} else {
|
|
} else {
|
|
- return fieldSchema.getDefaultValue(); // 2.2
|
|
|
|
|
|
+ return defaultVal; // 2.2
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -344,19 +344,18 @@ public class ParamUtils {
|
|
switch (dataType) {
|
|
switch (dataType) {
|
|
case FloatVector: {
|
|
case FloatVector: {
|
|
if (!(value.isJsonArray())) {
|
|
if (!(value.isJsonArray())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
- int dim = fieldSchema.getDimension();
|
|
|
|
try {
|
|
try {
|
|
List<Float> vector = JsonUtils.fromJson(value, new TypeToken<List<Float>>() {}.getType());
|
|
List<Float> vector = JsonUtils.fromJson(value, new TypeToken<List<Float>>() {}.getType());
|
|
if (vector.size() != dim) {
|
|
if (vector.size() != dim) {
|
|
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
|
|
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
|
|
- throw new ParamException(String.format(msg, fieldSchema.getName(), vector.size(), dim));
|
|
|
|
|
|
+ throw new ParamException(String.format(msg, fieldName, vector.size(), dim));
|
|
}
|
|
}
|
|
return vector; // return List<Float> for genFieldData()
|
|
return vector; // return List<Float> for genFieldData()
|
|
} catch (JsonSyntaxException e) {
|
|
} catch (JsonSyntaxException e) {
|
|
throw new ParamException(String.format("Unable to convert JsonArray to List<Float> for field '%s'. Reason: %s",
|
|
throw new ParamException(String.format("Unable to convert JsonArray to List<Float> for field '%s'. Reason: %s",
|
|
- fieldSchema.getName(), e.getCause().getMessage()));
|
|
|
|
|
|
+ fieldName, e.getCause().getMessage()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case BinaryVector:
|
|
case BinaryVector:
|
|
@@ -364,85 +363,84 @@ public class ParamUtils {
|
|
case BFloat16Vector:
|
|
case BFloat16Vector:
|
|
case Int8Vector: {
|
|
case Int8Vector: {
|
|
if (!(value.isJsonArray())) {
|
|
if (!(value.isJsonArray())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
- int dim = fieldSchema.getDimension();
|
|
|
|
try {
|
|
try {
|
|
byte[] v = JsonUtils.fromJson(value, new TypeToken<byte[]>() {}.getType());
|
|
byte[] v = JsonUtils.fromJson(value, new TypeToken<byte[]>() {}.getType());
|
|
int real_dim = calculateBinVectorDim(dataType, v.length);
|
|
int real_dim = calculateBinVectorDim(dataType, v.length);
|
|
if (real_dim != dim) {
|
|
if (real_dim != dim) {
|
|
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
|
|
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
|
|
- throw new ParamException(String.format(msg, fieldSchema.getName(), real_dim, dim));
|
|
|
|
|
|
+ throw new ParamException(String.format(msg, fieldName, real_dim, dim));
|
|
}
|
|
}
|
|
return ByteBuffer.wrap(v); // return ByteBuffer for genFieldData()
|
|
return ByteBuffer.wrap(v); // return ByteBuffer for genFieldData()
|
|
} catch (JsonSyntaxException e) {
|
|
} catch (JsonSyntaxException e) {
|
|
throw new ParamException(String.format("Unable to convert JsonArray to List<Float> for field '%s'. Reason: %s",
|
|
throw new ParamException(String.format("Unable to convert JsonArray to List<Float> for field '%s'. Reason: %s",
|
|
- fieldSchema.getName(), e.getCause().getMessage()));
|
|
|
|
|
|
+ fieldName, e.getCause().getMessage()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case SparseFloatVector:
|
|
case SparseFloatVector:
|
|
if (!(value.isJsonObject())) {
|
|
if (!(value.isJsonObject())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
try {
|
|
try {
|
|
// return SortedMap<Long, Float> for genFieldData()
|
|
// return SortedMap<Long, Float> for genFieldData()
|
|
return JsonUtils.fromJson(value, new TypeToken<SortedMap<Long, Float>>() {}.getType());
|
|
return JsonUtils.fromJson(value, new TypeToken<SortedMap<Long, Float>>() {}.getType());
|
|
} catch (JsonSyntaxException e) {
|
|
} catch (JsonSyntaxException e) {
|
|
throw new ParamException(String.format("Unable to convert JsonObject to SortedMap<Long, Float> for field '%s'. Reason: %s",
|
|
throw new ParamException(String.format("Unable to convert JsonObject to SortedMap<Long, Float> for field '%s'. Reason: %s",
|
|
- fieldSchema.getName(), e.getCause().getMessage()));
|
|
|
|
|
|
+ fieldName, e.getCause().getMessage()));
|
|
}
|
|
}
|
|
case Int64:
|
|
case Int64:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return value.getAsLong(); // return long for genFieldData()
|
|
return value.getAsLong(); // return long for genFieldData()
|
|
case Int32:
|
|
case Int32:
|
|
case Int16:
|
|
case Int16:
|
|
case Int8:
|
|
case Int8:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return value.getAsInt(); // return int for genFieldData()
|
|
return value.getAsInt(); // return int for genFieldData()
|
|
case Bool:
|
|
case Bool:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return value.getAsBoolean(); // return boolean for genFieldData()
|
|
return value.getAsBoolean(); // return boolean for genFieldData()
|
|
case Float:
|
|
case Float:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return value.getAsFloat(); // return float for genFieldData()
|
|
return value.getAsFloat(); // return float for genFieldData()
|
|
case Double:
|
|
case Double:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return value.getAsDouble(); // return double for genFieldData()
|
|
return value.getAsDouble(); // return double for genFieldData()
|
|
case VarChar:
|
|
case VarChar:
|
|
case String:
|
|
case String:
|
|
if (!(value.isJsonPrimitive())) {
|
|
if (!(value.isJsonPrimitive())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
JsonPrimitive p = value.getAsJsonPrimitive();
|
|
JsonPrimitive p = value.getAsJsonPrimitive();
|
|
if (!p.isString()) {
|
|
if (!p.isString()) {
|
|
- throw new ParamException(String.format("JsonPrimitive should be String type for field '%s'", fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format("JsonPrimitive should be String type for field '%s'", fieldName));
|
|
}
|
|
}
|
|
|
|
|
|
String str = p.getAsString();
|
|
String str = p.getAsString();
|
|
- if (str.length() > fieldSchema.getMaxLength()) {
|
|
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ if (str.length() > maxLength) {
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return str; // return String for genFieldData()
|
|
return str; // return String for genFieldData()
|
|
case JSON:
|
|
case JSON:
|
|
return value; // return JsonElement for genFieldData()
|
|
return value; // return JsonElement for genFieldData()
|
|
case Array:
|
|
case Array:
|
|
if (!(value.isJsonArray())) {
|
|
if (!(value.isJsonArray())) {
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
|
|
|
|
- List<Object> array = convertJsonArray(value.getAsJsonArray(), fieldSchema.getElementType(), fieldSchema.getName());
|
|
|
|
- if (array.size() > fieldSchema.getMaxCapacity()) {
|
|
|
|
- throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
|
|
+ List<Object> array = convertJsonArray(value.getAsJsonArray(), elementType, fieldName);
|
|
|
|
+ if (array.size() > maxCapacity) {
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
|
|
}
|
|
}
|
|
return array; // return List<Object> for genFieldData()
|
|
return array; // return List<Object> for genFieldData()
|
|
default:
|
|
default:
|
|
@@ -603,7 +601,7 @@ public class ParamUtils {
|
|
checkFieldData(fieldType, field);
|
|
checkFieldData(fieldType, field);
|
|
|
|
|
|
found = true;
|
|
found = true;
|
|
- this.addFieldsData(genFieldData(fieldType, field.getValues()));
|
|
|
|
|
|
+ this.addFieldsData(genFieldData(fieldType, field.getValues(), false));
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -668,7 +666,9 @@ public class ParamUtils {
|
|
String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
|
|
String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
- Object fieldValue = checkFieldValue(fieldType, rowFieldData);
|
|
|
|
|
|
+ Object fieldValue = checkFieldValue(fieldType.getName(), fieldType.getDataType(),
|
|
|
|
+ fieldType.getElementType(), fieldType.getDimension(), fieldType.getMaxLength(),
|
|
|
|
+ fieldType.getMaxCapacity(), fieldType.isNullable(), fieldType.getDefaultValue(), rowFieldData);
|
|
insertDataInfo.getData().add(fieldValue);
|
|
insertDataInfo.getData().add(fieldValue);
|
|
nameInsertInfo.put(fieldName, insertDataInfo);
|
|
nameInsertInfo.put(fieldName, insertDataInfo);
|
|
}
|
|
}
|
|
@@ -687,10 +687,10 @@ public class ParamUtils {
|
|
|
|
|
|
for (String fieldNameKey : nameInsertInfo.keySet()) {
|
|
for (String fieldNameKey : nameInsertInfo.keySet()) {
|
|
InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey);
|
|
InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey);
|
|
- this.addFieldsData(genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData()));
|
|
|
|
|
|
+ this.addFieldsData(genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData(), false));
|
|
}
|
|
}
|
|
if (wrapper.getEnableDynamicField()) {
|
|
if (wrapper.getEnableDynamicField()) {
|
|
- this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
|
|
|
|
|
|
+ this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), true));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -1161,23 +1161,23 @@ public class ParamUtils {
|
|
return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
|
|
return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
|
|
}
|
|
}
|
|
|
|
|
|
- public static FieldData genFieldData(FieldType fieldType, List<?> objects) {
|
|
|
|
- return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
|
|
|
|
+ private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
|
|
|
|
+ return genFieldData(fieldType.getName(), fieldType.getDataType(), fieldType.getElementType(),
|
|
|
|
+ fieldType.isNullable(), fieldType.getDefaultValue(), objects, isDynamic);
|
|
}
|
|
}
|
|
|
|
|
|
- public static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
|
|
|
|
|
|
+ public static FieldData genFieldData(String fieldName, DataType dataType, DataType elementType, boolean isNullable,
|
|
|
|
+ Object defaultVal, List<?> objects, boolean isDynamic) {
|
|
if (objects == null) {
|
|
if (objects == null) {
|
|
throw new ParamException("Cannot generate FieldData from null object");
|
|
throw new ParamException("Cannot generate FieldData from null object");
|
|
}
|
|
}
|
|
- DataType dataType = fieldType.getDataType();
|
|
|
|
- String fieldName = fieldType.getName();
|
|
|
|
- FieldData.Builder builder = FieldData.newBuilder();
|
|
|
|
|
|
|
|
|
|
+ FieldData.Builder builder = FieldData.newBuilder();
|
|
if (isVectorDataType(dataType)) {
|
|
if (isVectorDataType(dataType)) {
|
|
VectorField vectorField = genVectorField(dataType, objects);
|
|
VectorField vectorField = genVectorField(dataType, objects);
|
|
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
|
|
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
|
|
} else {
|
|
} else {
|
|
- if (fieldType.isNullable() || fieldType.getDefaultValue() != null) {
|
|
|
|
|
|
+ if (isNullable || defaultVal != null) {
|
|
List<Object> tempObjects = new ArrayList<>();
|
|
List<Object> tempObjects = new ArrayList<>();
|
|
for (Object obj : objects) {
|
|
for (Object obj : objects) {
|
|
builder.addValidData(obj != null);
|
|
builder.addValidData(obj != null);
|
|
@@ -1188,7 +1188,7 @@ public class ParamUtils {
|
|
objects = tempObjects;
|
|
objects = tempObjects;
|
|
}
|
|
}
|
|
|
|
|
|
- ScalarField scalarField = genScalarField(fieldType, objects);
|
|
|
|
|
|
+ ScalarField scalarField = genScalarField(dataType, elementType, objects);
|
|
if (isDynamic) {
|
|
if (isDynamic) {
|
|
return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
|
|
return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
|
|
}
|
|
}
|
|
@@ -1197,7 +1197,7 @@ public class ParamUtils {
|
|
}
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
@SuppressWarnings("unchecked")
|
|
- private static VectorField genVectorField(DataType dataType, List<?> objects) {
|
|
|
|
|
|
+ public static VectorField genVectorField(DataType dataType, List<?> objects) {
|
|
if (dataType == DataType.FloatVector) {
|
|
if (dataType == DataType.FloatVector) {
|
|
List<Float> floats = new ArrayList<>();
|
|
List<Float> floats = new ArrayList<>();
|
|
// each object is List<Float>
|
|
// each object is List<Float>
|
|
@@ -1323,22 +1323,18 @@ public class ParamUtils {
|
|
return builder.setDim(dim).build();
|
|
return builder.setDim(dim).build();
|
|
}
|
|
}
|
|
|
|
|
|
- private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
|
|
|
|
- if (fieldType.getDataType() == DataType.Array) {
|
|
|
|
|
|
+ public static ScalarField genScalarField(DataType dataType, DataType elementType, List<?> objects) {
|
|
|
|
+ if (dataType == DataType.Array) {
|
|
ArrayArray.Builder builder = ArrayArray.newBuilder();
|
|
ArrayArray.Builder builder = ArrayArray.newBuilder();
|
|
for (Object object : objects) {
|
|
for (Object object : objects) {
|
|
List<?> temp = (List<?>)object;
|
|
List<?> temp = (List<?>)object;
|
|
- ScalarField arrayField = genScalarField(fieldType.getElementType(), temp);
|
|
|
|
|
|
+ ScalarField arrayField = genScalarField(elementType, DataType.None, temp);
|
|
builder.addData(arrayField);
|
|
builder.addData(arrayField);
|
|
}
|
|
}
|
|
|
|
|
|
return ScalarField.newBuilder().setArrayData(builder.build()).build();
|
|
return ScalarField.newBuilder().setArrayData(builder.build()).build();
|
|
- } else {
|
|
|
|
- return genScalarField(fieldType.getDataType(), objects);
|
|
|
|
}
|
|
}
|
|
- }
|
|
|
|
|
|
|
|
- private static ScalarField genScalarField(DataType dataType, List<?> objects) {
|
|
|
|
switch (dataType) {
|
|
switch (dataType) {
|
|
case None:
|
|
case None:
|
|
case UNRECOGNIZED:
|
|
case UNRECOGNIZED:
|