Răsfoiți Sursa

Refine code (#1409)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 5 zile în urmă
părinte
comite
06eb1dcafa

+ 5 - 2
sdk-core/src/main/java/io/milvus/param/ParamUtils.java

@@ -1138,17 +1138,20 @@ public class ParamUtils {
         }
     }
 
-    public static boolean isVectorDataType(DataType dataType) {
+    public static boolean isDenseVectorDataType(DataType dataType) {
         Set<DataType> vectorDataType = new HashSet<DataType>() {{
             add(DataType.FloatVector);
             add(DataType.BinaryVector);
             add(DataType.Float16Vector);
             add(DataType.BFloat16Vector);
-            add(DataType.SparseFloatVector);
         }};
         return vectorDataType.contains(dataType);
     }
 
+    public static boolean isVectorDataType(DataType dataType) {
+        return isDenseVectorDataType(dataType) || dataType == DataType.SparseFloatVector;
+    }
+
     public static FieldData genFieldData(FieldType fieldType, List<?> objects) {
         return genFieldData(fieldType, objects, Boolean.FALSE);
     }

+ 21 - 21
sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -76,6 +76,8 @@ public class FieldDataWrapper {
     }
 
     // this method returns bytes size of each vector according to vector type
+    // for binary vector, each dimension is one bit, so each byte holds 8 dimensions
+    // for float16/bfloat16 vector, each dimension occupies 2 bytes
     private int checkDim(DataType dt, ByteString data, int dim) {
         if (dt == DataType.BinaryVector) {
             if ((data.size()*8) % dim != 0) {
@@ -96,6 +98,21 @@ public class FieldDataWrapper {
         return 0;
     }
 
+    private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
+        ByteString data;
+        if (dt == DataType.BinaryVector) {
+            data = fieldData.getVectors().getBinaryVector();
+        } else if (dt == DataType.Float16Vector) {
+            data = fieldData.getVectors().getFloat16Vector();
+        } else if (dt == DataType.BFloat16Vector) {
+            data = fieldData.getVectors().getBfloat16Vector();
+        } else {
+            String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
+            throw new IllegalResponseException(msg);
+        }
+        return data;
+    }
+
     /**
      * Gets the row count of a field.
      * * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -116,20 +133,11 @@ public class FieldDataWrapper {
 
                 return data.size()/dim;
             }
-            case BinaryVector: {
-                // for binary vector, each dimension is one bit, each byte is 8 dim
-                int dim = getDim();
-                ByteString data = fieldData.getVectors().getBinaryVector();
-                int bytePerVec = checkDim(dt, data, dim);
-
-                return data.size()/bytePerVec;
-            }
+            case BinaryVector:
             case Float16Vector:
             case BFloat16Vector: {
-                // for float16 vector, each dimension 2 bytes
                 int dim = getDim();
-                ByteString data = (dt == DataType.Float16Vector) ?
-                        fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
+                ByteString data = getVectorBytes(fieldData, dt);
                 int bytePerVec = checkDim(dt, data, dim);
 
                 return data.size()/bytePerVec;
@@ -213,22 +221,14 @@ public class FieldDataWrapper {
             case Float16Vector:
             case BFloat16Vector: {
                 int dim = getDim();
-                ByteString data = null;
-                if (dt == DataType.BinaryVector) {
-                    data = fieldData.getVectors().getBinaryVector();
-                } else if (dt == DataType.Float16Vector) {
-                    data = fieldData.getVectors().getFloat16Vector();
-                } else {
-                    data = fieldData.getVectors().getBfloat16Vector();
-                }
-
+                ByteString data = getVectorBytes(fieldData, dt);
                 int bytePerVec = checkDim(dt, data, dim);
                 int count = data.size()/bytePerVec;
                 List<ByteBuffer> packData = new ArrayList<>();
                 for (int i = 0; i < count; ++i) {
                     ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
                     // binary vector doesn't care endian since each byte is independent
-                    // fp16/bf16 vector is sensetive to endian because each dim occupies 2 bytes,
+                    // fp16/bf16 vector is sensitive to endian because each dim occupies 2 bytes,
                     // milvus server stores fp16/bf16 vector as little endian
                     bf.order(ByteOrder.LITTLE_ENDIAN);
                     bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());

+ 1 - 1
sdk-core/src/main/java/io/milvus/v2/service/collection/request/AddFieldReq.java

@@ -44,7 +44,7 @@ public class AddFieldReq {
     @Builder.Default
     private Boolean autoID = Boolean.FALSE;
     private Integer dimension;
-    private io.milvus.v2.common.DataType elementType;
+    private DataType elementType;
     private Integer maxCapacity;
     @Builder.Default
     private Boolean isNullable = Boolean.FALSE; // only for scalar fields(not include Array fields)

+ 2 - 2
sdk-core/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java

@@ -20,6 +20,7 @@
 package io.milvus.v2.service.collection.request;
 
 import io.milvus.common.clientenum.FunctionType;
+import io.milvus.param.ParamUtils;
 import io.milvus.v2.common.ConsistencyLevel;
 import io.milvus.v2.common.DataType;
 import io.milvus.v2.common.IndexParam;
@@ -166,8 +167,7 @@ public class CreateCollectionReq {
                 fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
             } else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
                 fieldSchema.setMaxLength(addFieldReq.getMaxLength());
-            } else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector) ||
-                    addFieldReq.getDataType().equals(DataType.Float16Vector) || addFieldReq.getDataType().equals(DataType.BFloat16Vector)) {
+            } else if (ParamUtils.isDenseVectorDataType(io.milvus.grpc.DataType.valueOf(addFieldReq.getDataType().name()))) {
                 if (addFieldReq.getDimension() == null) {
                     throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
                 }

+ 2 - 0
sdk-core/src/test/java/io/milvus/TestUtils.java

@@ -11,6 +11,8 @@ public class TestUtils {
     private int dimension = 256;
     private static final Random RANDOM = new Random();
 
+    public static final String MilvusDockerImageID = "milvusdb/milvus:v2.5.11";
+
     public TestUtils(int dimension) {
         this.dimension = dimension;
     }

+ 1 - 1
sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -75,7 +75,7 @@ class MilvusClientDockerTest {
     private static final TestUtils utils = new TestUtils(DIMENSION);
 
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
+    private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);
 
     @BeforeAll
     public static void setUp() {

+ 1 - 1
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -81,7 +81,7 @@ class MilvusClientV2DockerTest {
     private static final TestUtils utils = new TestUtils(DIMENSION);
 
     @Container
-    private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.5.11");
+    private static final MilvusContainer milvus = new MilvusContainer(TestUtils.MilvusDockerImageID);
 
     @BeforeAll
     public static void setUp() {