|
@@ -46,6 +46,7 @@ public class ParamUtils {
|
|
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
|
|
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
|
|
|
|
+ typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap");
|
|
return typeErrMsg;
|
|
return typeErrMsg;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -98,12 +99,11 @@ public class ParamUtils {
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ break;
|
|
}
|
|
}
|
|
- break;
|
|
|
|
case BinaryVector:
|
|
case BinaryVector:
|
|
case Float16Vector:
|
|
case Float16Vector:
|
|
- case BFloat16Vector:
|
|
|
|
- {
|
|
|
|
|
|
+ case BFloat16Vector: {
|
|
int dim = fieldSchema.getDimension();
|
|
int dim = fieldSchema.getDimension();
|
|
for (int i = 0; i < values.size(); ++i) {
|
|
for (int i = 0; i < values.size(); ++i) {
|
|
Object value = values.get(i);
|
|
Object value = values.get(i);
|
|
@@ -120,8 +120,30 @@ public class ParamUtils {
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
|
|
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ break;
|
|
}
|
|
}
|
|
- break;
|
|
|
|
|
|
+ case SparseFloatVector:
|
|
|
|
+ for (Object value : values) {
|
|
|
|
+ if (!(value instanceof SortedMap)) {
|
|
|
|
+ throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // is SortedMap<Long, Float> ?
|
|
|
|
+ SortedMap<?, ?> m = (SortedMap<?, ?>)value;
|
|
|
|
+ if (m.isEmpty()) { // not allow empty value for sparse vector
|
|
|
|
+ String msg = "Not allow empty SortedMap for sparse vector field '%s'";
|
|
|
|
+ throw new ParamException(String.format(msg, fieldSchema.getName()));
|
|
|
|
+ }
|
|
|
|
+ if (!(m.firstKey() instanceof Long)) {
|
|
|
|
+ String msg = "The key of SortedMap must be Long for sparse vector field '%s'";
|
|
|
|
+ throw new ParamException(String.format(msg, fieldSchema.getName()));
|
|
|
|
+ }
|
|
|
|
+ if (!(m.get(m.firstKey()) instanceof Float)) {
|
|
|
|
+ String msg = "The value of SortedMap must be Float for sparse vector field '%s'";
|
|
|
|
+ throw new ParamException(String.format(msg, fieldSchema.getName()));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ break;
|
|
case Int64:
|
|
case Int64:
|
|
for (Object value : values) {
|
|
for (Object value : values) {
|
|
if (!(value instanceof Long)) {
|
|
if (!(value instanceof Long)) {
|
|
@@ -286,7 +308,8 @@ public class ParamUtils {
|
|
}
|
|
}
|
|
if (isPartitionKeyEnabled) {
|
|
if (isPartitionKeyEnabled) {
|
|
if (partitionName != null && !partitionName.isEmpty()) {
|
|
if (partitionName != null && !partitionName.isEmpty()) {
|
|
- String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name";
|
|
|
|
|
|
+ String msg = String.format("Collection %s has partition key, not allow to specify partition name",
|
|
|
|
+ requestParam.getCollectionName());
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
} else if (partitionName != null) {
|
|
} else if (partitionName != null) {
|
|
@@ -314,7 +337,8 @@ public class ParamUtils {
|
|
for (InsertParam.Field field : fields) {
|
|
for (InsertParam.Field field : fields) {
|
|
if (field.getName().equals(fieldType.getName())) {
|
|
if (field.getName().equals(fieldType.getName())) {
|
|
if (fieldType.isAutoID()) {
|
|
if (fieldType.isAutoID()) {
|
|
- String msg = "The primary key: " + fieldType.getName() + " is auto generated, no need to input.";
|
|
|
|
|
|
+ String msg = String.format("The primary key: %s is auto generated, no need to input.",
|
|
|
|
+ fieldType.getName());
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
checkFieldData(fieldType, field);
|
|
checkFieldData(fieldType, field);
|
|
@@ -326,8 +350,7 @@ public class ParamUtils {
|
|
|
|
|
|
}
|
|
}
|
|
if (!found && !fieldType.isAutoID()) {
|
|
if (!found && !fieldType.isAutoID()) {
|
|
- String msg = "The field: " + fieldType.getName() + " is not provided.";
|
|
|
|
- throw new ParamException(msg);
|
|
|
|
|
|
+ throw new ParamException(String.format("The field: %s is not provided.", fieldType.getName()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -369,7 +392,7 @@ public class ParamUtils {
|
|
Object rowFieldData = row.get(fieldName);
|
|
Object rowFieldData = row.get(fieldName);
|
|
if (rowFieldData != null) {
|
|
if (rowFieldData != null) {
|
|
if (fieldType.isAutoID()) {
|
|
if (fieldType.isAutoID()) {
|
|
- String msg = "The primary key: " + fieldName + " is auto generated, no need to input.";
|
|
|
|
|
|
+ String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false);
|
|
checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false);
|
|
@@ -379,7 +402,7 @@ public class ParamUtils {
|
|
} else {
|
|
} else {
|
|
// check if autoId
|
|
// check if autoId
|
|
if (!fieldType.isAutoID()) {
|
|
if (!fieldType.isAutoID()) {
|
|
- String msg = "The field: " + fieldType.getName() + " is not provided.";
|
|
|
|
|
|
+ String msg = String.format("The field: %s is not provided.", fieldType.getName());
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -455,8 +478,16 @@ public class ParamUtils {
|
|
byte[] array = buf.array();
|
|
byte[] array = buf.array();
|
|
ByteString bs = ByteString.copyFrom(array);
|
|
ByteString bs = ByteString.copyFrom(array);
|
|
byteStrings.add(bs);
|
|
byteStrings.add(bs);
|
|
|
|
+ } else if (vector instanceof SortedMap) {
|
|
|
|
+ plType = PlaceholderType.SparseFloatVector;
|
|
|
|
+ SortedMap<Long, Float> map = (SortedMap<Long, Float>) vector;
|
|
|
|
+ ByteString bs = genSparseFloatBytes(map);
|
|
|
|
+ byteStrings.add(bs);
|
|
} else {
|
|
} else {
|
|
- String msg = "Search target vector type is illegal(Only allow List<Float> or ByteBuffer)";
|
|
|
|
|
|
+ String msg = "Search target vector type is illegal." +
|
|
|
|
+ " Only allow List<Float> for FloatVector," +
|
|
|
|
+ " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
|
|
|
|
+ " List<SortedMap<Long, Float>> for SparseFloatVector.";
|
|
throw new ParamException(msg);
|
|
throw new ParamException(msg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -623,6 +654,7 @@ public class ParamUtils {
|
|
add(DataType.BinaryVector);
|
|
add(DataType.BinaryVector);
|
|
add(DataType.Float16Vector);
|
|
add(DataType.Float16Vector);
|
|
add(DataType.BFloat16Vector);
|
|
add(DataType.BFloat16Vector);
|
|
|
|
+ add(DataType.SparseFloatVector);
|
|
}};
|
|
}};
|
|
return vectorDataType.contains(dataType);
|
|
return vectorDataType.contains(dataType);
|
|
}
|
|
}
|
|
@@ -631,7 +663,6 @@ public class ParamUtils {
|
|
return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
return genFieldData(fieldType, objects, Boolean.FALSE);
|
|
}
|
|
}
|
|
|
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
|
|
private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
|
|
if (objects == null) {
|
|
if (objects == null) {
|
|
throw new ParamException("Cannot generate FieldData from null object");
|
|
throw new ParamException("Cannot generate FieldData from null object");
|
|
@@ -694,11 +725,56 @@ public class ParamUtils {
|
|
} else {
|
|
} else {
|
|
return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
|
|
return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
|
|
}
|
|
}
|
|
|
|
+ } else if (dataType == DataType.SparseFloatVector) {
|
|
|
|
+ SparseFloatArray sparseArray = genSparseFloatArray(objects);
|
|
|
|
+ return VectorField.newBuilder().setDim(sparseArray.getDim()).setSparseFloatVector(sparseArray).build();
|
|
}
|
|
}
|
|
|
|
|
|
throw new ParamException("Illegal vector dataType:" + dataType);
|
|
throw new ParamException("Illegal vector dataType:" + dataType);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
|
|
|
|
+ ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
|
|
|
|
+ buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default
|
|
|
|
+ for (Map.Entry<Long, Float> entry : sparse.entrySet()) {
|
|
|
|
+ long k = entry.getKey();
|
|
|
|
+ if (k < 0 || k >= (long)Math.pow(2.0, 32.0)-1) {
|
|
|
|
+ throw new ParamException("Sparse vector index must be positive and less than 2^32-1");
|
|
|
|
+ }
|
|
|
|
+ // here we construct a binary from the long key
|
|
|
|
+ ByteBuffer lBuf = ByteBuffer.allocate(Long.BYTES);
|
|
|
|
+ lBuf.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
|
+ lBuf.putLong(k);
|
|
|
|
+ // the server requires a binary of unsigned int, append the first 4 bytes
|
|
|
|
+ buf.put(lBuf.array(), 0, 4);
|
|
|
|
+
|
|
|
|
+ float v = entry.getValue();
|
|
|
|
+ if (Float.isNaN(v) || Float.isInfinite(v)) {
|
|
|
|
+ throw new ParamException("Sparse vector value cannot be NaN or Infinite");
|
|
|
|
+ }
|
|
|
|
+ buf.putFloat(entry.getValue());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return ByteString.copyFrom(buf.array());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static SparseFloatArray genSparseFloatArray(List<?> objects) {
|
|
|
|
+ int dim = 0; // the real dim is unknown, set the max size as dim
|
|
|
|
+ SparseFloatArray.Builder builder = SparseFloatArray.newBuilder();
|
|
|
|
+ // each object must be SortedMap<Long, Float>, which is already validated by checkFieldData()
|
|
|
|
+ for (Object object : objects) {
|
|
|
|
+ if (!(object instanceof SortedMap)) {
|
|
|
|
+ throw new ParamException("SparseFloatVector vector field's value type must be SortedMap");
|
|
|
|
+ }
|
|
|
|
+ SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) object;
|
|
|
|
+ dim = Math.max(dim, sparse.size());
|
|
|
|
+ ByteString byteString = genSparseFloatBytes(sparse);
|
|
|
|
+ builder.addContents(byteString);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return builder.setDim(dim).build();
|
|
|
|
+ }
|
|
|
|
+
|
|
private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
|
|
private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
|
|
if (fieldType.getDataType() == DataType.Array) {
|
|
if (fieldType.getDataType() == DataType.Array) {
|
|
ArrayArray.Builder builder = ArrayArray.newBuilder();
|
|
ArrayArray.Builder builder = ArrayArray.newBuilder();
|