|
@@ -44,6 +44,8 @@ public class ParamUtils {
|
|
typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String");
|
|
typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String");
|
|
typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List<Float>");
|
|
typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List<Float>");
|
|
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
|
|
|
|
+ typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
|
|
|
|
+ typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
|
|
return typeErrMsg;
|
|
return typeErrMsg;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -52,6 +54,18 @@ public class ParamUtils {
|
|
checkFieldData(fieldSchema, values, false);
|
|
checkFieldData(fieldSchema, values, false);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private static int calculateBinVectorDim(DataType dataType, int byteCount) {
|
|
|
|
+ if (dataType == DataType.BinaryVector) {
|
|
|
|
+ return byteCount*8; // for BinaryVector, each byte is 8 dimensions
|
|
|
|
+ } else {
|
|
|
|
+ if (byteCount%2 != 0) {
|
|
|
|
+ String msg = "Incorrect byte count for %s type field, byte count is %d, cannot be evenly divided by 2";
|
|
|
|
+ throw new ParamException(String.format(msg, dataType.name(), byteCount));
|
|
|
|
+ }
|
|
|
|
+ return byteCount/2; // for float16/bfloat16, each dimension is 2 bytes
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean verifyElementType) {
|
|
public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean verifyElementType) {
|
|
HashMap<DataType, String> errMsgs = getTypeErrorMsg();
|
|
HashMap<DataType, String> errMsgs = getTypeErrorMsg();
|
|
DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
|
|
DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
|
|
@@ -86,7 +100,10 @@ public class ParamUtils {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
break;
|
|
- case BinaryVector: {
|
|
|
|
|
|
+ case BinaryVector:
|
|
|
|
+ case Float16Vector:
|
|
|
|
+ case BFloat16Vector:
|
|
|
|
+ {
|
|
int dim = fieldSchema.getDimension();
|
|
int dim = fieldSchema.getDimension();
|
|
for (int i = 0; i < values.size(); ++i) {
|
|
for (int i = 0; i < values.size(); ++i) {
|
|
Object value = values.get(i);
|
|
Object value = values.get(i);
|
|
@@ -97,7 +114,8 @@ public class ParamUtils {
|
|
|
|
|
|
// check dimension
|
|
// check dimension
|
|
ByteBuffer v = (ByteBuffer)value;
|
|
ByteBuffer v = (ByteBuffer)value;
|
|
- if (v.position()*8 != dim) {
|
|
|
|
|
|
+ int real_dim = calculateBinVectorDim(dataType, v.position());
|
|
|
|
+ if (real_dim != dim) {
|
|
String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
|
|
String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
|
|
}
|
|
}
|
|
@@ -599,11 +617,15 @@ public class ParamUtils {
|
|
return guaranteeTimestamp;
|
|
return guaranteeTimestamp;
|
|
}
|
|
}
|
|
|
|
|
|
-
|
|
|
|
- private static final Set<DataType> vectorDataType = new HashSet<DataType>() {{
|
|
|
|
- add(DataType.FloatVector);
|
|
|
|
- add(DataType.BinaryVector);
|
|
|
|
- }};
|
|
|
|
|
|
+ public static boolean isVectorDataType(DataType dataType) {
|
|
|
|
+ Set<DataType> vectorDataType = new HashSet<DataType>() {{
|
|
|
|
+ add(DataType.FloatVector);
|
|
|
|
+ add(DataType.BinaryVector);
|
|
|
|
+ add(DataType.Float16Vector);
|
|
|
|
+ add(DataType.BFloat16Vector);
|
|
|
|
+ }};
|
|
|
|
+ return vectorDataType.contains(dataType);
|
|
|
|
+ }
|
|
|
|
|
|
private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
|
|
private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
|
|
return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
@@ -617,7 +639,7 @@ public class ParamUtils {
|
|
DataType dataType = fieldType.getDataType();
|
|
DataType dataType = fieldType.getDataType();
|
|
String fieldName = fieldType.getName();
|
|
String fieldName = fieldType.getName();
|
|
FieldData.Builder builder = FieldData.newBuilder();
|
|
FieldData.Builder builder = FieldData.newBuilder();
|
|
- if (vectorDataType.contains(dataType)) {
|
|
|
|
|
|
+ if (isVectorDataType(dataType)) {
|
|
VectorField vectorField = genVectorField(dataType, objects);
|
|
VectorField vectorField = genVectorField(dataType, objects);
|
|
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
|
|
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
|
|
} else {
|
|
} else {
|
|
@@ -646,7 +668,9 @@ public class ParamUtils {
|
|
int dim = floats.size() / objects.size();
|
|
int dim = floats.size() / objects.size();
|
|
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
|
|
return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
|
|
return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
|
|
- } else if (dataType == DataType.BinaryVector) {
|
|
|
|
|
|
+ } else if (dataType == DataType.BinaryVector ||
|
|
|
|
+ dataType == DataType.Float16Vector ||
|
|
|
|
+ dataType == DataType.BFloat16Vector) {
|
|
ByteBuffer totalBuf = null;
|
|
ByteBuffer totalBuf = null;
|
|
int dim = 0;
|
|
int dim = 0;
|
|
// each object is ByteBuffer
|
|
// each object is ByteBuffer
|
|
@@ -655,7 +679,7 @@ public class ParamUtils {
|
|
if (totalBuf == null) {
|
|
if (totalBuf == null) {
|
|
totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
|
|
totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
|
|
totalBuf.put(buf.array());
|
|
totalBuf.put(buf.array());
|
|
- dim = buf.position() * 8;
|
|
|
|
|
|
+ dim = calculateBinVectorDim(dataType, buf.position());
|
|
} else {
|
|
} else {
|
|
totalBuf.put(buf.array());
|
|
totalBuf.put(buf.array());
|
|
}
|
|
}
|
|
@@ -663,7 +687,13 @@ public class ParamUtils {
|
|
|
|
|
|
assert totalBuf != null;
|
|
assert totalBuf != null;
|
|
ByteString byteString = ByteString.copyFrom(totalBuf.array());
|
|
ByteString byteString = ByteString.copyFrom(totalBuf.array());
|
|
- return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
|
|
|
|
|
|
+ if (dataType == DataType.BinaryVector) {
|
|
|
|
+ return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
|
|
|
|
+ } else if (dataType == DataType.Float16Vector) {
|
|
|
|
+ return VectorField.newBuilder().setDim(dim).setFloat16Vector(byteString).build();
|
|
|
|
+ } else {
|
|
|
|
+ return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
throw new ParamException("Illegal vector dataType:" + dataType);
|
|
throw new ParamException("Illegal vector dataType:" + dataType);
|