|
@@ -49,12 +49,17 @@ public class ParamUtils {
|
|
|
|
|
|
private static void checkFieldData(FieldType fieldSchema, InsertParam.Field fieldData) {
|
|
|
List<?> values = fieldData.getValues();
|
|
|
- checkFieldData(fieldSchema, values);
|
|
|
+ checkFieldData(fieldSchema, values, false);
|
|
|
}
|
|
|
|
|
|
- private static void checkFieldData(FieldType fieldSchema, List<?> values) {
|
|
|
+ private static void checkFieldData(FieldType fieldSchema, List<?> values, boolean verifyElementType) {
|
|
|
HashMap<DataType, String> errMsgs = getTypeErrorMsg();
|
|
|
- DataType dataType = fieldSchema.getDataType();
|
|
|
+ DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
|
|
|
+
|
|
|
+ if (verifyElementType && values.size() > fieldSchema.getMaxCapacity()) {
|
|
|
+ throw new ParamException(String.format("Array field '%s' length: %d exceeds max capacity: %d",
|
|
|
+ fieldSchema.getName(), values.size(), fieldSchema.getMaxCapacity()));
|
|
|
+ }
|
|
|
|
|
|
switch (dataType) {
|
|
|
case FloatVector: {
|
|
@@ -151,6 +156,16 @@ public class ParamUtils {
|
|
|
}
|
|
|
}
|
|
|
break;
|
|
|
+ case Array:
|
|
|
+ for (Object value : values) {
|
|
|
+ if (!(value instanceof List)) {
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
+ }
|
|
|
+
|
|
|
+ List<?> temp = (List<?>)value;
|
|
|
+ checkFieldData(fieldSchema, temp, true);
|
|
|
+ }
|
|
|
+ break;
|
|
|
default:
|
|
|
throw new IllegalResponseException("Unsupported data type returned by FieldData");
|
|
|
}
|
|
@@ -330,7 +345,7 @@ public class ParamUtils {
|
|
|
checkFieldData(fieldType, field);
|
|
|
|
|
|
found = true;
|
|
|
- this.addFieldsData(genFieldData(field.getName(), fieldType.getDataType(), field.getValues()));
|
|
|
+ this.addFieldsData(genFieldData(fieldType, field.getValues()));
|
|
|
break;
|
|
|
}
|
|
|
|
|
@@ -346,12 +361,18 @@ public class ParamUtils {
|
|
|
List<FieldType> fieldTypes = wrapper.getFields();
|
|
|
|
|
|
Map<String, InsertDataInfo> nameInsertInfo = new HashMap<>();
|
|
|
- InsertDataInfo insertDynamicDataInfo = InsertDataInfo.builder().dataType(DataType.JSON).data(new LinkedList<>()).build();
|
|
|
+ InsertDataInfo insertDynamicDataInfo = InsertDataInfo.builder().fieldType(
|
|
|
+ FieldType.newBuilder()
|
|
|
+ .withName(Constant.DYNAMIC_FIELD_NAME)
|
|
|
+ .withDataType(DataType.JSON)
|
|
|
+ .withIsDynamic(true)
|
|
|
+ .build())
|
|
|
+ .data(new LinkedList<>()).build();
|
|
|
for (JSONObject row : rows) {
|
|
|
for (FieldType fieldType : fieldTypes) {
|
|
|
String fieldName = fieldType.getName();
|
|
|
InsertDataInfo insertDataInfo = nameInsertInfo.getOrDefault(fieldName, InsertDataInfo.builder()
|
|
|
- .fieldName(fieldName).dataType(fieldType.getDataType()).data(new LinkedList<>()).build());
|
|
|
+ .fieldType(fieldType).data(new LinkedList<>()).build());
|
|
|
|
|
|
// check normalField
|
|
|
Object rowFieldData = row.get(fieldName);
|
|
@@ -360,7 +381,7 @@ public class ParamUtils {
|
|
|
String msg = "The primary key: " + fieldName + " is auto generated, no need to input.";
|
|
|
throw new ParamException(msg);
|
|
|
}
|
|
|
- checkFieldData(fieldType, Lists.newArrayList(rowFieldData));
|
|
|
+ checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false);
|
|
|
|
|
|
insertDataInfo.getData().add(rowFieldData);
|
|
|
nameInsertInfo.put(fieldName, insertDataInfo);
|
|
@@ -387,10 +408,10 @@ public class ParamUtils {
|
|
|
|
|
|
for (String fieldNameKey : nameInsertInfo.keySet()) {
|
|
|
InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey);
|
|
|
- this.addFieldsData(genFieldData(insertDataInfo.getFieldName(), insertDataInfo.getDataType(), insertDataInfo.getData()));
|
|
|
+ this.addFieldsData(genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData()));
|
|
|
}
|
|
|
if (wrapper.getEnableDynamicField()) {
|
|
|
- this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldName(), insertDynamicDataInfo.getDataType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
|
|
|
+ this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -601,111 +622,132 @@ public class ParamUtils {
|
|
|
add(DataType.BinaryVector);
|
|
|
}};
|
|
|
|
|
|
- private static FieldData genFieldData(String fieldName, DataType dataType, List<?> objects) {
|
|
|
- return genFieldData(fieldName, dataType, objects, Boolean.FALSE);
|
|
|
+ private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
|
|
|
+ return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private static FieldData genFieldData(String fieldName, DataType dataType, List<?> objects, boolean isDynamic) {
|
|
|
+ private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
|
|
|
if (objects == null) {
|
|
|
throw new ParamException("Cannot generate FieldData from null object");
|
|
|
}
|
|
|
+ DataType dataType = fieldType.getDataType();
|
|
|
+ String fieldName = fieldType.getName();
|
|
|
FieldData.Builder builder = FieldData.newBuilder();
|
|
|
if (vectorDataType.contains(dataType)) {
|
|
|
- if (dataType == DataType.FloatVector) {
|
|
|
- List<Float> floats = new ArrayList<>();
|
|
|
- // each object is List<Float>
|
|
|
- for (Object object : objects) {
|
|
|
- if (object instanceof List) {
|
|
|
- List<Float> list = (List<Float>) object;
|
|
|
- floats.addAll(list);
|
|
|
- } else {
|
|
|
- throw new ParamException("The type of FloatVector must be List<Float>");
|
|
|
- }
|
|
|
+ VectorField vectorField = genVectorField(dataType, objects);
|
|
|
+ return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
|
|
|
+ } else {
|
|
|
+ ScalarField scalarField = genScalarField(fieldType, objects);
|
|
|
+ if (isDynamic) {
|
|
|
+ return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
|
|
|
+ }
|
|
|
+ return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ private static VectorField genVectorField(DataType dataType, List<?> objects) {
|
|
|
+ if (dataType == DataType.FloatVector) {
|
|
|
+ List<Float> floats = new ArrayList<>();
|
|
|
+ // each object is List<Float>
|
|
|
+ for (Object object : objects) {
|
|
|
+ if (object instanceof List) {
|
|
|
+ List<Float> list = (List<Float>) object;
|
|
|
+ floats.addAll(list);
|
|
|
+ } else {
|
|
|
+ throw new ParamException("The type of FloatVector must be List<Float>");
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- int dim = floats.size() / objects.size();
|
|
|
- FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
- VectorField vectorField = VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(DataType.FloatVector).setVectors(vectorField).build();
|
|
|
- } else if (dataType == DataType.BinaryVector) {
|
|
|
- ByteBuffer totalBuf = null;
|
|
|
- int dim = 0;
|
|
|
- // each object is ByteBuffer
|
|
|
- for (Object object : objects) {
|
|
|
- ByteBuffer buf = (ByteBuffer) object;
|
|
|
- if (totalBuf == null) {
|
|
|
- totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
|
|
|
- totalBuf.put(buf.array());
|
|
|
- dim = buf.position() * 8;
|
|
|
- } else {
|
|
|
- totalBuf.put(buf.array());
|
|
|
- }
|
|
|
+ int dim = floats.size() / objects.size();
|
|
|
+ FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
+ return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
|
|
|
+ } else if (dataType == DataType.BinaryVector) {
|
|
|
+ ByteBuffer totalBuf = null;
|
|
|
+ int dim = 0;
|
|
|
+ // each object is ByteBuffer
|
|
|
+ for (Object object : objects) {
|
|
|
+ ByteBuffer buf = (ByteBuffer) object;
|
|
|
+ if (totalBuf == null) {
|
|
|
+ totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
|
|
|
+ totalBuf.put(buf.array());
|
|
|
+ dim = buf.position() * 8;
|
|
|
+ } else {
|
|
|
+ totalBuf.put(buf.array());
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ assert totalBuf != null;
|
|
|
+ ByteString byteString = ByteString.copyFrom(totalBuf.array());
|
|
|
+ return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
|
|
|
+ }
|
|
|
|
|
|
- assert totalBuf != null;
|
|
|
- ByteString byteString = ByteString.copyFrom(totalBuf.array());
|
|
|
- VectorField vectorField = VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
|
|
|
- return builder.setFieldName(fieldName).setType(DataType.BinaryVector).setVectors(vectorField).build();
|
|
|
+ throw new ParamException("Illegal vector dataType:" + dataType);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
|
|
|
+ if (fieldType.getDataType() == DataType.Array) {
|
|
|
+ ArrayArray.Builder builder = ArrayArray.newBuilder();
|
|
|
+ for (Object object : objects) {
|
|
|
+ List<?> temp = (List<?>)object;
|
|
|
+ ScalarField arrayField = genScalarField(fieldType.getElementType(), temp);
|
|
|
+ builder.addData(arrayField);
|
|
|
}
|
|
|
+
|
|
|
+ return ScalarField.newBuilder().setArrayData(builder.build()).build();
|
|
|
} else {
|
|
|
- switch (dataType) {
|
|
|
- case None:
|
|
|
- case UNRECOGNIZED:
|
|
|
- throw new ParamException("Cannot support this dataType:" + dataType);
|
|
|
- case Int64: {
|
|
|
- List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
|
|
|
- LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setLongData(longArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case Int32:
|
|
|
- case Int16:
|
|
|
- case Int8: {
|
|
|
- List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
|
|
|
- IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setIntData(intArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case Bool: {
|
|
|
- List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
|
|
|
- BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setBoolData(boolArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case Float: {
|
|
|
- List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
|
|
|
- FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setFloatData(floatArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case Double: {
|
|
|
- List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
|
|
|
- DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setDoubleData(doubleArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case String:
|
|
|
- case VarChar: {
|
|
|
- List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
|
|
|
- StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setStringData(stringArray).build();
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- case JSON: {
|
|
|
- List<ByteString> byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(((JSONObject) p).toJSONString()))
|
|
|
- .collect(Collectors.toList());
|
|
|
- JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
|
|
|
- ScalarField scalarField = ScalarField.newBuilder().setJsonData(jsonArray).build();
|
|
|
- if (isDynamic) {
|
|
|
- return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
|
|
|
- }
|
|
|
- return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
|
|
|
- }
|
|
|
- }
|
|
|
+ return genScalarField(fieldType.getDataType(), objects);
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- return null;
|
|
|
+ private static ScalarField genScalarField(DataType dataType, List<?> objects) {
|
|
|
+ switch (dataType) {
|
|
|
+ case None:
|
|
|
+ case UNRECOGNIZED:
|
|
|
+ throw new ParamException("Cannot support this dataType:" + dataType);
|
|
|
+ case Int64: {
|
|
|
+ List<Long> longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
|
|
|
+ LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
|
|
|
+ return ScalarField.newBuilder().setLongData(longArray).build();
|
|
|
+ }
|
|
|
+ case Int32:
|
|
|
+ case Int16:
|
|
|
+ case Int8: {
|
|
|
+ List<Integer> integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
|
|
|
+ IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
|
|
|
+ return ScalarField.newBuilder().setIntData(intArray).build();
|
|
|
+ }
|
|
|
+ case Bool: {
|
|
|
+ List<Boolean> booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
|
|
|
+ BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
|
|
|
+ return ScalarField.newBuilder().setBoolData(boolArray).build();
|
|
|
+ }
|
|
|
+ case Float: {
|
|
|
+ List<Float> floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
|
|
|
+ FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
|
+ return ScalarField.newBuilder().setFloatData(floatArray).build();
|
|
|
+ }
|
|
|
+ case Double: {
|
|
|
+ List<Double> doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
|
|
|
+ DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
|
|
|
+ return ScalarField.newBuilder().setDoubleData(doubleArray).build();
|
|
|
+ }
|
|
|
+ case String:
|
|
|
+ case VarChar: {
|
|
|
+ List<String> strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
|
|
|
+ StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
|
|
|
+ return ScalarField.newBuilder().setStringData(stringArray).build();
|
|
|
+ }
|
|
|
+ case JSON: {
|
|
|
+ List<ByteString> byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(((JSONObject) p).toJSONString()))
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
|
|
|
+ return ScalarField.newBuilder().setJsonData(jsonArray).build();
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ throw new ParamException("Illegal scalar dataType:" + dataType);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -722,6 +764,7 @@ public class ParamUtils {
|
|
|
.withPartitionKey(field.getIsPartitionKey())
|
|
|
.withAutoID(field.getAutoID())
|
|
|
.withDataType(field.getDataType())
|
|
|
+ .withElementType(field.getElementType())
|
|
|
.withIsDynamic(field.getIsDynamic());
|
|
|
|
|
|
if (field.getIsDynamic()) {
|
|
@@ -748,6 +791,7 @@ public class ParamUtils {
|
|
|
.setIsPartitionKey(field.isPartitionKey())
|
|
|
.setAutoID(field.isAutoID())
|
|
|
.setDataType(field.getDataType())
|
|
|
+ .setElementType(field.getElementType())
|
|
|
.setIsDynamic(field.isDynamic());
|
|
|
|
|
|
// assemble typeParams for CollectionSchema
|
|
@@ -776,8 +820,7 @@ public class ParamUtils {
|
|
|
@Builder
|
|
|
@Getter
|
|
|
public static class InsertDataInfo {
|
|
|
- private final String fieldName;
|
|
|
- private final DataType dataType;
|
|
|
+ private final FieldType fieldType;
|
|
|
private final LinkedList<Object> data;
|
|
|
}
|
|
|
}
|