|
@@ -3,18 +3,18 @@ package io.milvus.response;
|
|
|
import com.alibaba.fastjson.JSONObject;
|
|
|
import com.google.protobuf.ProtocolStringList;
|
|
|
import io.milvus.exception.ParamException;
|
|
|
-import io.milvus.grpc.ArrayArray;
|
|
|
-import io.milvus.grpc.DataType;
|
|
|
-import io.milvus.grpc.FieldData;
|
|
|
+import io.milvus.grpc.*;
|
|
|
import io.milvus.exception.IllegalResponseException;
|
|
|
|
|
|
-import io.milvus.grpc.ScalarField;
|
|
|
import io.milvus.param.ParamUtils;
|
|
|
import lombok.NonNull;
|
|
|
|
|
|
import java.nio.ByteBuffer;
|
|
|
+import java.nio.ByteOrder;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
+import java.util.SortedMap;
|
|
|
+import java.util.TreeMap;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import com.google.protobuf.ByteString;
|
|
@@ -56,6 +56,27 @@ public class FieldDataWrapper {
|
|
|
return (int) fieldData.getVectors().getDim();
|
|
|
}
|
|
|
|
|
|
+ // this method returns bytes size of each vector according to vector type
|
|
|
+ private int checkDim(DataType dt, ByteString data, int dim) {
|
|
|
+ if (dt == DataType.BinaryVector) {
|
|
|
+ if ((data.size()*8) % dim != 0) {
|
|
|
+ String msg = String.format("Returned binary vector data array size %d doesn't match dimension %d",
|
|
|
+ data.size(), dim);
|
|
|
+ throw new IllegalResponseException(msg);
|
|
|
+ }
|
|
|
+ return dim/8;
|
|
|
+ } else if (dt == DataType.Float16Vector || dt == DataType.BFloat16Vector) {
|
|
|
+ if (data.size() % (dim*2) != 0) {
|
|
|
+ String msg = String.format("Returned float16 vector data array size %d doesn't match dimension %d",
|
|
|
+ data.size(), dim);
|
|
|
+ throw new IllegalResponseException(msg);
|
|
|
+ }
|
|
|
+ return dim*2;
|
|
|
+ }
|
|
|
+
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Gets the row count of a field.
|
|
|
* * Throws {@link IllegalResponseException} if the field type is illegal.
|
|
@@ -69,19 +90,34 @@ public class FieldDataWrapper {
|
|
|
int dim = getDim();
|
|
|
List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
|
|
|
if (data.size() % dim != 0) {
|
|
|
- throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
|
|
|
+ String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
|
|
|
+ data.size(), dim);
|
|
|
+ throw new IllegalResponseException(msg);
|
|
|
}
|
|
|
|
|
|
return data.size()/dim;
|
|
|
}
|
|
|
case BinaryVector: {
|
|
|
+ // for binary vector, each dimension is one bit, each byte is 8 dim
|
|
|
int dim = getDim();
|
|
|
ByteString data = fieldData.getVectors().getBinaryVector();
|
|
|
- if ((data.size()*8) % dim != 0) {
|
|
|
- throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
|
|
|
- }
|
|
|
+ int bytePerVec = checkDim(dt, data, dim);
|
|
|
|
|
|
- return (data.size()*8)/dim;
|
|
|
+ return data.size()/bytePerVec;
|
|
|
+ }
|
|
|
+ case Float16Vector:
|
|
|
+ case BFloat16Vector: {
|
|
|
+ // for float16 vector, each dimension 2 bytes
|
|
|
+ int dim = getDim();
|
|
|
+ ByteString data = (dt == DataType.Float16Vector) ?
|
|
|
+ fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
|
|
|
+ int bytePerVec = checkDim(dt, data, dim);
|
|
|
+
|
|
|
+ return data.size()/bytePerVec;
|
|
|
+ }
|
|
|
+ case SparseFloatVector: {
|
|
|
+ // for sparse vector, each content is a vector
|
|
|
+ return fieldData.getVectors().getSparseFloatVector().getContentsCount();
|
|
|
}
|
|
|
case Int64:
|
|
|
return fieldData.getScalars().getLongData().getDataCount();
|
|
@@ -109,15 +145,17 @@ public class FieldDataWrapper {
|
|
|
|
|
|
/**
|
|
|
* Returns the field data according to its type:
|
|
|
- * float vector field return List of List Float,
|
|
|
- * binary vector field return List of ByteBuffer
|
|
|
- * int64 field return List of Long
|
|
|
- * int32/int16/int8 field return List of Integer
|
|
|
- * boolean field return List of Boolean
|
|
|
- * float field return List of Float
|
|
|
- * double field return List of Double
|
|
|
- * varchar field return List of String
|
|
|
- * array field return List of List
|
|
|
+ * FloatVector field returns List of List Float,
|
|
|
+ * BinaryVector/Float16Vector/BFloat16Vector fields return List of ByteBuffer
|
|
|
+ * SparseFloatVector field returns List of SortedMap[Long, Float]
|
|
|
+ * Int64 field returns List of Long
|
|
|
+ * Int32/Int16/Int8 fields return List of Integer
|
|
|
+ * Bool field returns List of Boolean
|
|
|
+ * Float field returns List of Float
|
|
|
+ * Double field returns List of Double
|
|
|
+ * Varchar field returns List of String
|
|
|
+ * Array field returns List of List
|
|
|
+ * JSON field returns List of String;
|
|
|
* etc.
|
|
|
*
|
|
|
* Throws {@link IllegalResponseException} if the field type is illegal.
|
|
@@ -131,7 +169,9 @@ public class FieldDataWrapper {
|
|
|
int dim = getDim();
|
|
|
List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
|
|
|
if (data.size() % dim != 0) {
|
|
|
- throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
|
|
|
+ String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
|
|
|
+ data.size(), dim);
|
|
|
+ throw new IllegalResponseException(msg);
|
|
|
}
|
|
|
|
|
|
List<List<Float>> packData = new ArrayList<>();
|
|
@@ -141,16 +181,22 @@ public class FieldDataWrapper {
|
|
|
}
|
|
|
return packData;
|
|
|
}
|
|
|
- case BinaryVector: {
|
|
|
+ case BinaryVector:
|
|
|
+ case Float16Vector:
|
|
|
+ case BFloat16Vector: {
|
|
|
int dim = getDim();
|
|
|
- ByteString data = fieldData.getVectors().getBinaryVector();
|
|
|
- if ((data.size()*8) % dim != 0) {
|
|
|
- throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
|
|
|
+ ByteString data = null;
|
|
|
+ if (dt == DataType.BinaryVector) {
|
|
|
+ data = fieldData.getVectors().getBinaryVector();
|
|
|
+ } else if (dt == DataType.Float16Vector) {
|
|
|
+ data = fieldData.getVectors().getFloat16Vector();
|
|
|
+ } else {
|
|
|
+ data = fieldData.getVectors().getBfloat16Vector();
|
|
|
}
|
|
|
|
|
|
- List<ByteBuffer> packData = new ArrayList<>();
|
|
|
- int bytePerVec = dim/8;
|
|
|
+ int bytePerVec = checkDim(dt, data, dim);
|
|
|
int count = data.size()/bytePerVec;
|
|
|
+ List<ByteBuffer> packData = new ArrayList<>();
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
|
|
|
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
|
|
@@ -158,6 +204,40 @@ public class FieldDataWrapper {
|
|
|
}
|
|
|
return packData;
|
|
|
}
|
|
|
+ case SparseFloatVector: {
|
|
|
+ // in Java sdk, each sparse vector is pairs of long+float
|
|
|
+ // in server side, each sparse vector is stored as uint+float (8 bytes)
|
|
|
+ // don't use sparseArray.getDim() because the dim is the max index of each rows
|
|
|
+ SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector();
|
|
|
+ List<SortedMap<Long, Float>> packData = new ArrayList<>();
|
|
|
+ for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
|
|
|
+ ByteString bs = sparseArray.getContents(i);
|
|
|
+ ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
|
|
|
+ bf.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
+ SortedMap<Long, Float> sparse = new TreeMap<>();
|
|
|
+ long num = bf.limit()/8; // each uint+float pair is 8 bytes
|
|
|
+ for (long j = 0; j < num; j++) {
|
|
|
+ // here we convert an uint 4-bytes to a long value
|
|
|
+ ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
|
|
|
+ pBuf.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
+ int offset = 8*(int)j;
|
|
|
+ byte[] aa = bf.array();
|
|
|
+ for (int k = offset; k < offset + 4; k++) {
|
|
|
+ pBuf.put(aa[k]); // fill the first 4 bytes with the unit bytes
|
|
|
+ }
|
|
|
+ pBuf.putInt(0); // fill the last 4 bytes to zero
|
|
|
+ pBuf.rewind(); // reset position to head
|
|
|
+ long k = pBuf.getLong(); // this is the long value converted from the uint
|
|
|
+
|
|
|
+ // here we get the float value as normal
|
|
|
+ bf.position(offset+4); // position offsets 4 bytes since they were converted to long
|
|
|
+ float v = bf.getFloat(); // this is the float value
|
|
|
+ sparse.put(k, v);
|
|
|
+ }
|
|
|
+ packData.add(sparse);
|
|
|
+ }
|
|
|
+ return packData;
|
|
|
+ }
|
|
|
case Array:
|
|
|
List<List<?>> array = new ArrayList<>();
|
|
|
ArrayArray arrArray = fieldData.getScalars().getArrayData();
|
|
@@ -202,7 +282,7 @@ public class FieldDataWrapper {
|
|
|
return protoStrList.subList(0, protoStrList.size());
|
|
|
case JSON:
|
|
|
List<ByteString> dataList = scalar.getJsonData().getDataList();
|
|
|
- return dataList.stream().map(ByteString::toByteArray).collect(Collectors.toList());
|
|
|
+ return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
|
|
|
default:
|
|
|
return new ArrayList<>();
|
|
|
}
|
|
@@ -249,14 +329,25 @@ public class FieldDataWrapper {
|
|
|
}
|
|
|
|
|
|
public Object valueByIdx(int index) throws ParamException {
|
|
|
- if (index < 0 || index >= getFieldData().size()) {
|
|
|
- throw new ParamException("index out of range");
|
|
|
+ List<?> data = getFieldData();
|
|
|
+ if (index < 0 || index >= data.size()) {
|
|
|
+ throw new ParamException(String.format("Value index %d out of range %d", index, data.size()));
|
|
|
}
|
|
|
- return getFieldData().get(index);
|
|
|
+ return data.get(index);
|
|
|
}
|
|
|
|
|
|
private JSONObject parseObjectData(int index) {
|
|
|
Object object = valueByIdx(index);
|
|
|
- return JSONObject.parseObject(new String((byte[])object));
|
|
|
+ return ParseJSONObject(object);
|
|
|
+ }
|
|
|
+
|
|
|
+ public static JSONObject ParseJSONObject(Object object) {
|
|
|
+ if (object instanceof String) {
|
|
|
+ return JSONObject.parseObject((String)object);
|
|
|
+ } else if (object instanceof byte[]) {
|
|
|
+ return JSONObject.parseObject(new String((byte[]) object));
|
|
|
+ } else {
|
|
|
+ throw new IllegalResponseException("Illegal type value for JSON parser");
|
|
|
+ }
|
|
|
}
|
|
|
}
|