Jelajahi Sumber

Support BFloat16/Float16 vector (#812)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 tahun lalu
induk
melakukan
450a6c18c4

+ 10 - 9
README.md

@@ -14,12 +14,13 @@ Java SDK for [Milvus](https://github.com/milvus-io/milvus). To contribute to thi
 The following table shows compatibilities between Milvus and Java SDK.
 
 | Milvus version | Java SDK version |
-| :------------: |:----------------:|
-|     2.0      |      2.0.4       |
-|     2.1      |   2.1.0-beta4    |
-|     2.2.0 ~ 2.2.8      |      2.2.0 ~ 2.2.5       |
-|     >= 2.2.9      |      2.2.7 ~ 2.2.15       |
-|     2.3.x      |      2.3.3      |
+|:--------------:|:----------------:|
+|      2.0       |      2.0.4       |
+|      2.1       |   2.1.0-beta4    |
+| 2.2.0 ~ 2.2.8  |  2.2.0 ~ 2.2.5   |
+|    >= 2.2.9    |  2.2.7 ~ 2.2.15  |
+|     2.3.x      |      2.3.4       |
+|     2.4.x      |      2.4.0       |
 
 ### Install Java SDK
 
@@ -31,20 +32,20 @@ You can use **Apache Maven** or **Gradle** add Milvus SDK to your project.
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>2.3.3</version>
+            <version>2.4.0</version>
         </dependency>
        ```
 
    - Gradle/Groovy
 
         ```groovy
-        implementation 'io.milvus:milvus-sdk-java:2.3.3'
+        implementation 'io.milvus:milvus-sdk-java:2.4.0'
         ```
 
    - Gradle/Kotlin
 
         ```kotlin
-        implementation("io.milvus:milvus-sdk-java:2.3.3")
+        implementation("io.milvus:milvus-sdk-java:2.4.0")
         ```
         
 ### Examples

+ 2 - 2
docker-compose.yml

@@ -32,7 +32,7 @@ services:
 
   standalone:
     container_name: milvus-javasdk-test-standalone
-    image: milvusdb/milvus:master-20240229-50a78b68
+    image: milvusdb/milvus:v2.4.0-rc.1
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcd:2379
@@ -77,7 +77,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-test-slave-standalone
-    image: milvusdb/milvus:master-20240229-50a78b68
+    image: milvusdb/milvus:v2.4.0-rc.1
     command: ["milvus", "run", "standalone"]
     environment:
       ETCD_ENDPOINTS: etcdslave:2379

+ 207 - 0
examples/main/java/io/milvus/Float16Example.java

@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus;
+
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.common.clientenum.ConsistencyLevelEnum;
+import io.milvus.grpc.*;
+import io.milvus.param.*;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.*;
+import io.milvus.param.index.*;
+import io.milvus.response.*;
+
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import org.tensorflow.ndarray.buffer.ByteDataBuffer;
+import org.tensorflow.types.*;
+
+
+/**
+ * Example of using a Float16Vector or BFloat16Vector field with the Milvus Java SDK:
+ * create a collection, insert random vectors, build an index, load and search.
+ * Vectors of both types are handed to the SDK as ByteBuffer, two bytes per dimension.
+ */
+public class Float16Example {
+    private static final String COLLECTION_NAME = "java_sdk_example_float16";
+    private static final String ID_FIELD = "id";
+    private static final String VECTOR_FIELD = "vector";
+    private static final Integer VECTOR_DIM = 128;
+
+    /**
+     * Generates random vectors whose elements are encoded as 16-bit floating point values.
+     *
+     * @param count    number of vectors to generate
+     * @param bfloat16 true to encode each element as bfloat16, false to encode it as float16
+     * @return list of buffers, each filled with VECTOR_DIM*2 bytes
+     */
+    private static List<ByteBuffer> generateVectors(int count, boolean bfloat16) {
+        Random ran = new Random();
+        List<ByteBuffer> vectors = new ArrayList<>();
+        int byteCount = VECTOR_DIM*2;
+        for (int n = 0; n < count; ++n) {
+            ByteBuffer vector = ByteBuffer.allocate(byteCount);
+            for (int i = 0; i < VECTOR_DIM; ++i) {
+                float value = (float)ran.nextInt(VECTOR_DIM);
+                // The encoder must match the requested type: TBfloat16 for bfloat16, TFloat16 for float16
+                ByteDataBuffer bf;
+                if (bfloat16) {
+                    bf = TBfloat16.scalarOf(value).asRawTensor().data();
+                } else {
+                    bf = TFloat16.scalarOf(value).asRawTensor().data();
+                }
+                // Each element occupies exactly two bytes of the raw tensor data
+                vector.put(bf.getByte(0));
+                vector.put(bf.getByte(1));
+            }
+            // The buffer is deliberately not flipped/rewound: the SDK's dimension check
+            // (ParamUtils.checkFieldData) computes the dimension from the buffer's position()
+            vectors.add(vector);
+        }
+
+        return vectors;
+    }
+
+    /**
+     * Throws a RuntimeException carrying the response message if the call did not succeed.
+     *
+     * @param r response returned by a MilvusServiceClient call
+     */
+    private static void handleResponseStatus(R<?> r) {
+        if (r.getStatus() != R.Status.Success.getCode()) {
+            throw new RuntimeException(r.getMessage());
+        }
+    }
+
+    /**
+     * Runs the whole example for one vector type and always closes the client afterwards.
+     *
+     * @param bfloat16 true to test BFloat16Vector, false to test Float16Vector
+     */
+    private static void testFloat16(boolean bfloat16) {
+        DataType dataType = bfloat16 ? DataType.BFloat16Vector : DataType.Float16Vector;
+        System.out.printf("=================== %s ===================\n", dataType.name());
+
+        // Connect to Milvus server. Replace the "localhost" and port with your Milvus server address.
+        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build());
+        try {
+            runFloat16Test(milvusClient, dataType, bfloat16);
+        } finally {
+            // Close the client even when one of the steps throws
+            milvusClient.close();
+        }
+    }
+
+    /**
+     * Creates the collection, inserts data, builds the index, loads, searches and drops the collection.
+     *
+     * @param milvusClient connected client; the caller is responsible for closing it
+     * @param dataType     vector field type, DataType.BFloat16Vector or DataType.Float16Vector
+     * @param bfloat16     true if the generated vectors must be bfloat16 encoded, false for float16
+     */
+    private static void runFloat16Test(MilvusServiceClient milvusClient, DataType dataType, boolean bfloat16) {
+        // Drop the collection left over from a previous run, if any, to start from a clean state
+        R<Boolean> hasR = milvusClient.hasCollection(HasCollectionParam.newBuilder()
+                                    .withCollectionName(COLLECTION_NAME)
+                                    .build());
+        handleResponseStatus(hasR);
+        if (hasR.getData()) {
+            handleResponseStatus(milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .build()));
+        }
+
+        // Define fields
+        List<FieldType> fieldsSchema = Arrays.asList(
+                FieldType.newBuilder()
+                        .withName(ID_FIELD)
+                        .withDataType(DataType.Int64)
+                        .withPrimaryKey(true)
+                        .withAutoID(false)
+                        .build(),
+                FieldType.newBuilder()
+                        .withName(VECTOR_FIELD)
+                        .withDataType(dataType)
+                        .withDimension(VECTOR_DIM)
+                        .build()
+        );
+
+        // Create the collection
+        R<RpcStatus> ret = milvusClient.createCollection(CreateCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
+                .withFieldTypes(fieldsSchema)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection created");
+
+        // Insert entities
+        int rowCount = 10000;
+        List<Long> ids = new ArrayList<>();
+        for (long i = 0L; i < rowCount; ++i) {
+            ids.add(i);
+        }
+        List<ByteBuffer> vectors = generateVectors(rowCount, bfloat16);
+
+        List<InsertParam.Field> fieldsInsert = new ArrayList<>();
+        fieldsInsert.add(new InsertParam.Field(ID_FIELD, ids));
+        fieldsInsert.add(new InsertParam.Field(VECTOR_FIELD, vectors));
+
+        InsertParam insertParam = InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFields(fieldsInsert)
+                .build();
+
+        R<MutationResult> insertR = milvusClient.insert(insertParam);
+        handleResponseStatus(insertR);
+
+        // Flush the data to storage for testing purpose
+        // Note that no need to manually call flush interface in practice
+        R<FlushResponse> flushR = milvusClient.flush(FlushParam.newBuilder().
+                addCollectionName(COLLECTION_NAME).
+                build());
+        handleResponseStatus(flushR);
+        System.out.println("Entities inserted");
+
+        // Specify an index type on the vector field.
+        ret = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(VECTOR_FIELD)
+                .withIndexType(IndexType.IVF_FLAT)
+                .withMetricType(MetricType.L2)
+                .withExtraParam("{\"nlist\":128}")
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Index created");
+
+        // Call loadCollection() to enable automatically loading data into memory for searching
+        ret = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build());
+        handleResponseStatus(ret);
+        System.out.println("Collection loaded");
+
+        // Pick some vectors from the inserted vectors to search
+        // Ensure the returned top1 item's ID should be equal to target vector's ID
+        Random ran = new Random();
+        for (int i = 0; i < 10; i++) {
+            int k = ran.nextInt(rowCount);
+            ByteBuffer targetVector = vectors.get(k);
+            R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
+                    .withCollectionName(COLLECTION_NAME)
+                    .withMetricType(MetricType.L2)
+                    .withTopK(3)
+                    .withVectors(Collections.singletonList(targetVector))
+                    .withVectorFieldName(VECTOR_FIELD)
+                    .withParams("{\"nprobe\":32}")
+                    .build());
+            // Check the status of this search call before reading its data
+            handleResponseStatus(searchRet);
+
+            // The search() allows multiple target vectors to search in a batch.
+            // Here we only input one vector to search, get the result of No.0 vector to check
+            SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
+            List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
+            System.out.printf("The result of No.%d target vector:\n", i);
+            for (SearchResultsWrapper.IDScore score : scores) {
+                System.out.println(score);
+            }
+            if (scores.get(0).getLongID() != k) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
+                        scores.get(0).getLongID(), k));
+            }
+        }
+
+        // drop the collection if you don't need the collection anymore
+        handleResponseStatus(milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .build()));
+    }
+
+    /**
+     * Runs the example twice: first with BFloat16Vector, then with Float16Vector.
+     *
+     * @param args command line arguments, not used
+     */
+    public static void main(String[] args) {
+        testFloat16(true);
+        testFloat16(false);
+    }
+}

+ 12 - 5
examples/pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java-examples</artifactId>
-    <version>2.3.2</version>
+    <version>2.4.0</version>
 
     <build>
         <plugins>
@@ -64,7 +64,14 @@
         <dependency>
             <groupId>io.milvus</groupId>
             <artifactId>milvus-sdk-java</artifactId>
-            <version>2.3.2</version>
+            <version>2.4.0</version>
+
+            <exclusions>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-log4j12</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>com.google.code.gson</groupId>
@@ -72,9 +79,9 @@
             <version>2.8.9</version>
         </dependency>
         <dependency>
-            <groupId>org.slf4j</groupId>
-            <artifactId>slf4j-api</artifactId>
-            <version>1.7.30</version>
+            <groupId>org.tensorflow</groupId>
+            <artifactId>tensorflow-core-platform</artifactId>
+            <version>0.5.0</version>
         </dependency>
     </dependencies>
 

+ 1 - 1
pom.xml

@@ -25,7 +25,7 @@
 
     <groupId>io.milvus</groupId>
     <artifactId>milvus-sdk-java</artifactId>
-    <version>2.3.2</version>
+    <version>2.4.0</version>
     <packaging>jar</packaging>
 
     <name>io.milvus:milvus-sdk-java</name>

+ 41 - 11
src/main/java/io/milvus/param/ParamUtils.java

@@ -44,6 +44,8 @@ public class ParamUtils {
         typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String");
         typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List<Float>");
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
         return typeErrMsg;
     }
 
@@ -52,6 +54,18 @@ public class ParamUtils {
         checkFieldData(fieldSchema, values, false);
     }
 
+    /**
+     * Converts the byte length of a single vector into its dimension.
+     *
+     * @param dataType  vector type; BinaryVector is counted as 8 dimensions per byte,
+     *                  any other type is counted as 2 bytes per dimension
+     * @param byteCount number of bytes held by the vector
+     * @return the dimension of the vector
+     * @throws ParamException if a 2-bytes-per-dimension vector has an odd byte count
+     */
+    private static int calculateBinVectorDim(DataType dataType, int byteCount) {
+        // for BinaryVector, each byte is 8 dimensions
+        if (dataType == DataType.BinaryVector) {
+            return byteCount * 8;
+        }
+
+        // for float16/bfloat16, each dimension is 2 bytes, so an odd byte count is invalid
+        boolean oddByteCount = (byteCount % 2) != 0;
+        if (oddByteCount) {
+            throw new ParamException(String.format(
+                    "Incorrect byte count for %s type field, byte count is %d, cannot be evenly divided by 2",
+                    dataType.name(), byteCount));
+        }
+        return byteCount / 2;
+    }
+
     public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean verifyElementType) {
         HashMap<DataType, String> errMsgs = getTypeErrorMsg();
         DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
@@ -86,7 +100,10 @@ public class ParamUtils {
                 }
             }
             break;
-            case BinaryVector: {
+            case BinaryVector:
+            case Float16Vector:
+            case BFloat16Vector:
+            {
                 int dim = fieldSchema.getDimension();
                 for (int i = 0; i < values.size(); ++i) {
                     Object value  = values.get(i);
@@ -97,7 +114,8 @@ public class ParamUtils {
 
                     // check dimension
                     ByteBuffer v = (ByteBuffer)value;
-                    if (v.position()*8 != dim) {
+                    int real_dim = calculateBinVectorDim(dataType, v.position());
+                    if (real_dim != dim) {
                         String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
                     }
@@ -599,11 +617,15 @@ public class ParamUtils {
         return guaranteeTimestamp;
     }
 
-
-    private static final Set<DataType> vectorDataType = new HashSet<DataType>() {{
-        add(DataType.FloatVector);
-        add(DataType.BinaryVector);
-    }};
+    /**
+     * Tells whether the given data type is one of the vector types handled by this class:
+     * FloatVector, BinaryVector, Float16Vector or BFloat16Vector.
+     *
+     * @param dataType the data type to test
+     * @return true if the type is a vector type, false otherwise (including null)
+     */
+    public static boolean isVectorDataType(DataType dataType) {
+        // Direct comparisons avoid building a new anonymous HashSet subclass on every call;
+        // a null argument yields false, as HashSet.contains(null) did
+        return dataType == DataType.FloatVector
+                || dataType == DataType.BinaryVector
+                || dataType == DataType.Float16Vector
+                || dataType == DataType.BFloat16Vector;
+    }
 
     private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
         return genFieldData(fieldType, objects, Boolean.FALSE);
@@ -617,7 +639,7 @@ public class ParamUtils {
         DataType dataType = fieldType.getDataType();
         String fieldName = fieldType.getName();
         FieldData.Builder builder = FieldData.newBuilder();
-        if (vectorDataType.contains(dataType)) {
+        if (isVectorDataType(dataType)) {
             VectorField vectorField = genVectorField(dataType, objects);
             return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
         } else {
@@ -646,7 +668,9 @@ public class ParamUtils {
             int dim = floats.size() / objects.size();
             FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
             return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
-        } else if (dataType == DataType.BinaryVector) {
+        } else if (dataType == DataType.BinaryVector ||
+                dataType == DataType.Float16Vector ||
+                dataType == DataType.BFloat16Vector) {
             ByteBuffer totalBuf = null;
             int dim = 0;
             // each object is ByteBuffer
@@ -655,7 +679,7 @@ public class ParamUtils {
                 if (totalBuf == null) {
                     totalBuf = ByteBuffer.allocate(buf.position() * objects.size());
                     totalBuf.put(buf.array());
-                    dim = buf.position() * 8;
+                    dim = calculateBinVectorDim(dataType, buf.position());
                 } else {
                     totalBuf.put(buf.array());
                 }
@@ -663,7 +687,13 @@ public class ParamUtils {
 
             assert totalBuf != null;
             ByteString byteString = ByteString.copyFrom(totalBuf.array());
-            return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
+            if (dataType == DataType.BinaryVector) {
+                return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
+            } else if (dataType == DataType.Float16Vector) {
+                return VectorField.newBuilder().setDim(dim).setFloat16Vector(byteString).build();
+            } else {
+                return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
+            }
         }
 
         throw new ParamException("Illegal vector dataType:" + dataType);

+ 1 - 1
src/main/java/io/milvus/param/QueryNodeSingleSearch.java

@@ -100,7 +100,7 @@ public class QueryNodeSingleSearch {
          *
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float
-         *                if vector type is BinaryVector, vectors is List of ByteBuffer
+         *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer
          * @return <code>Builder</code>
          */
         public Builder withVectors(@NonNull List<?> vectors) {

+ 3 - 4
src/main/java/io/milvus/param/collection/FieldType.java

@@ -271,7 +271,7 @@ public class FieldType {
                 throw new ParamException("String type is not supported, use Varchar instead");
             }
 
-            if (dataType == DataType.FloatVector || dataType == DataType.BinaryVector) {
+            if (ParamUtils.isVectorDataType(dataType)) {
                 if (!typeParams.containsKey(Constant.VECTOR_DIM)) {
                     throw new ParamException("Vector field dimension must be specified");
                 }
@@ -317,9 +317,8 @@ public class FieldType {
                     throw new ParamException("String type is not supported, use Varchar instead");
                 }
                 if (elementType == DataType.None || elementType == DataType.Array
-                        || elementType == DataType.JSON || elementType == DataType.String
-                        || elementType == DataType.FloatVector || elementType == DataType.Float16Vector
-                        || elementType == DataType.BinaryVector || elementType == DataType.UNRECOGNIZED) {
+                        || elementType == DataType.JSON || ParamUtils.isVectorDataType(elementType)
+                        || elementType == DataType.UNRECOGNIZED) {
                     throw new ParamException("Unsupported element type");
                 }
 

+ 1 - 1
src/main/java/io/milvus/param/dml/InsertParam.java

@@ -217,7 +217,7 @@ public class InsertParam {
      * If dataType is Double, values is List of Double;
      * If dataType is Varchar, values is List of String;
      * If dataType is FloatVector, values is List of List Float;
-     * If dataType is BinaryVector, values is List of ByteBuffer;
+     * If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer;
      * If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String;
      *
      * Note:

+ 1 - 1
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -237,7 +237,7 @@ public class SearchParam {
          *
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float;
-         *                if vector type is BinaryVector, vectors is List of ByteBuffer;
+         *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
          * @return <code>Builder</code>
          */
         public Builder withVectors(@NonNull List<?> vectors) {

+ 1 - 1
src/main/java/io/milvus/param/highlevel/dml/SearchSimpleParam.java

@@ -122,7 +122,7 @@ public class SearchSimpleParam {
          *
          * @param vectors list of target vectors:
          *               if vector type is FloatVector, vectors is List of List Float;
-         *               if vector type is BinaryVector, vectors is List of ByteBuffer;
+         *               if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
          * @return <code>Builder</code>
          */
         public Builder withVectors(@NonNull List<?> vectors) {

+ 1 - 1
src/main/java/io/milvus/response/DescCollResponseWrapper.java

@@ -177,7 +177,7 @@ public class DescCollResponseWrapper {
         CollectionSchema schema = response.getSchema();
         for (int i = 0; i < schema.getFieldsCount(); ++i) {
             FieldSchema field = schema.getFields(i);
-            if (field.getDataType() == DataType.FloatVector || field.getDataType() == DataType.BinaryVector) {
+            if (ParamUtils.isVectorDataType(field.getDataType())) {
                 return ParamUtils.ConvertField(field);
             }
         }

+ 2 - 1
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -9,6 +9,7 @@ import io.milvus.grpc.FieldData;
 import io.milvus.exception.IllegalResponseException;
 
 import io.milvus.grpc.ScalarField;
+import io.milvus.param.ParamUtils;
 import lombok.NonNull;
 
 import java.nio.ByteBuffer;
@@ -31,7 +32,7 @@ public class FieldDataWrapper {
     }
 
     public boolean isVectorField() {
-        return fieldData.getType() == DataType.FloatVector || fieldData.getType() == DataType.BinaryVector;
+        return ParamUtils.isVectorDataType(fieldData.getType());
     }
 
     public boolean isJsonField() {

+ 2 - 0
src/main/java/io/milvus/v2/utils/DataUtils.java

@@ -459,6 +459,8 @@ public class DataUtils {
         typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String");
         typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List<Float>");
         typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
+        typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
         return typeErrMsg;
     }
 }