|
@@ -76,6 +76,8 @@ public class FieldDataWrapper {
|
|
|
}
|
|
|
|
|
|
// this method returns bytes size of each vector according to vector type
|
|
|
+ // for binary vector, each dimension is one bit, each byte is 8 dim
|
|
|
+ // for float16 vector, each dimension 2 bytes
|
|
|
private int checkDim(DataType dt, ByteString data, int dim) {
|
|
|
if (dt == DataType.BinaryVector) {
|
|
|
if ((data.size()*8) % dim != 0) {
|
|
@@ -96,6 +98,21 @@ public class FieldDataWrapper {
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
+ private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
|
|
|
+ ByteString data;
|
|
|
+ if (dt == DataType.BinaryVector) {
|
|
|
+ data = fieldData.getVectors().getBinaryVector();
|
|
|
+ } else if (dt == DataType.Float16Vector) {
|
|
|
+ data = fieldData.getVectors().getFloat16Vector();
|
|
|
+ } else if (dt == DataType.BFloat16Vector) {
|
|
|
+ data = fieldData.getVectors().getBfloat16Vector();
|
|
|
+ } else {
|
|
|
+ String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
|
|
|
+ throw new IllegalResponseException(msg);
|
|
|
+ }
|
|
|
+ return data;
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Gets the row count of a field.
|
|
|
* * Throws {@link IllegalResponseException} if the field type is illegal.
|
|
@@ -116,20 +133,11 @@ public class FieldDataWrapper {
|
|
|
|
|
|
return data.size()/dim;
|
|
|
}
|
|
|
- case BinaryVector: {
|
|
|
- // for binary vector, each dimension is one bit, each byte is 8 dim
|
|
|
- int dim = getDim();
|
|
|
- ByteString data = fieldData.getVectors().getBinaryVector();
|
|
|
- int bytePerVec = checkDim(dt, data, dim);
|
|
|
-
|
|
|
- return data.size()/bytePerVec;
|
|
|
- }
|
|
|
+ case BinaryVector:
|
|
|
case Float16Vector:
|
|
|
case BFloat16Vector: {
|
|
|
- // for float16 vector, each dimension 2 bytes
|
|
|
int dim = getDim();
|
|
|
- ByteString data = (dt == DataType.Float16Vector) ?
|
|
|
- fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
|
|
|
+ ByteString data = getVectorBytes(fieldData, dt);
|
|
|
int bytePerVec = checkDim(dt, data, dim);
|
|
|
|
|
|
return data.size()/bytePerVec;
|
|
@@ -213,22 +221,14 @@ public class FieldDataWrapper {
|
|
|
case Float16Vector:
|
|
|
case BFloat16Vector: {
|
|
|
int dim = getDim();
|
|
|
- ByteString data = null;
|
|
|
- if (dt == DataType.BinaryVector) {
|
|
|
- data = fieldData.getVectors().getBinaryVector();
|
|
|
- } else if (dt == DataType.Float16Vector) {
|
|
|
- data = fieldData.getVectors().getFloat16Vector();
|
|
|
- } else {
|
|
|
- data = fieldData.getVectors().getBfloat16Vector();
|
|
|
- }
|
|
|
-
|
|
|
+ ByteString data = getVectorBytes(fieldData, dt);
|
|
|
int bytePerVec = checkDim(dt, data, dim);
|
|
|
int count = data.size()/bytePerVec;
|
|
|
List<ByteBuffer> packData = new ArrayList<>();
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
|
|
|
// binary vector doesn't care endian since each byte is independent
|
|
|
- // fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
|
|
|
+ // fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes,
|
|
|
// milvus server stores fp16/bf16 vector as little endian
|
|
|
bf.order(ByteOrder.LITTLE_ENDIAN);
|
|
|
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
|