Browse Source

BulkWriter supports Int8Vector (#1441)

* Add Development section to README.md (#1436)

- Include link to DEVELOPMENT.md
- Briefly outline key development topics
- Encourage community contributions

Signed-off-by: Gong Yi <topikachu@163.com>

* BulkWriter supports Int8Vector (#1440)

Signed-off-by: yhmo <yihua.mo@zilliz.com>

---------

Signed-off-by: Gong Yi <topikachu@163.com>
Signed-off-by: yhmo <yihua.mo@zilliz.com>
Co-authored-by: GongYi <topikachu@163.com>
groot 1 day ago
parent
commit
40e7215c61

+ 12 - 1
README.md

@@ -114,4 +114,15 @@ Please refer to [examples](https://github.com/milvus-io/milvus-sdk-java/tree/mas
         implementation("org.slf4j:slf4j-api:1.7.30")
         ```
 
-    
+### Development
+
+For developers interested in contributing to the Milvus Java SDK, please refer to our [DEVELOPMENT.md](DEVELOPMENT.md) file. This document provides detailed instructions on:
+
+- Setting up your development environment
+- Cloning the repository
+- Building the SDK
+- Updating Milvus proto files
+- Running tests
+- Contributing guidelines
+
+We welcome contributions from the community!

+ 21 - 0
examples/src/main/java/io/milvus/v1/CommonUtils.java

@@ -262,6 +262,27 @@ public class CommonUtils {
         return vectors;
     }
 
+    /////////////////////////////////////////////////////////////////////////////////////////////////////
+    public static ByteBuffer generateInt8Vector(int dimension) {
+        Random ran = new Random();
+        int byteCount = dimension;
+        // binary vector doesn't care endian since each byte is independent
+        ByteBuffer vector = ByteBuffer.allocate(byteCount);
+        for (int i = 0; i < byteCount; ++i) {
+            vector.put((byte) (ran.nextInt(256) - 128));
+        }
+        return vector;
+    }
+
+    public static List<ByteBuffer> generateInt8Vectors(int dimension, int count) {
+        List<ByteBuffer> vectors = new ArrayList<>();
+        for (int n = 0; n < count; ++n) {
+            ByteBuffer vector = generateInt8Vector(dimension);
+            vectors.add(vector);
+        }
+        return vectors;
+    }
+
     /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static SortedMap<Long, Float> generateSparseVector() {
         Random ran = new Random();

+ 9 - 9
examples/src/main/java/io/milvus/v2/BulkWriterExample.java

@@ -373,7 +373,7 @@ public class BulkWriterExample {
             // vector field
             row.put("float_vector", CommonUtils.generateFloatVector(DIM));
             row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
-            row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
+            row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
             row.put("sparse_vector", CommonUtils.generateSparseVector());
 
             // array field
@@ -405,7 +405,7 @@ public class BulkWriterExample {
             // vector field
             row.put("float_vector", CommonUtils.generateFloatVector(DIM));
             row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
-            row.put("float16_vector", CommonUtils.generateFloat16Vector(DIM, false).array());
+            row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
             row.put("sparse_vector", CommonUtils.generateSparseVector());
 
             // array field
@@ -450,7 +450,7 @@ public class BulkWriterExample {
             // vector field
             rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(row.get("float_vector")));
             rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(row.get("binary_vector")));
-            rowObject.add("float16_vector", GSON_INSTANCE.toJsonTree(row.get("float16_vector")));
+            rowObject.add("int8_vector", GSON_INSTANCE.toJsonTree(row.get("int8_vector")));
             rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(row.get("sparse_vector")));
 
             // array field
@@ -791,7 +791,7 @@ public class BulkWriterExample {
 
             comparePrint(collectionSchema, originalEntity, fetchedEntity, "float_vector");
             comparePrint(collectionSchema, originalEntity, fetchedEntity, "binary_vector");
-            comparePrint(collectionSchema, originalEntity, fetchedEntity, "float16_vector");
+            comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8_vector");
             comparePrint(collectionSchema, originalEntity, fetchedEntity, "sparse_vector");
 
             System.out.println(fetchedEntity);
@@ -815,9 +815,9 @@ public class BulkWriterExample {
                 .metricType(IndexParam.MetricType.HAMMING)
                 .build());
         indexes.add(IndexParam.builder()
-                .fieldName("float16_vector")
-                .indexType(IndexParam.IndexType.FLAT)
-                .metricType(IndexParam.MetricType.IP)
+                .fieldName("int8_vector")
+                .indexType(IndexParam.IndexType.AUTOINDEX)
+                .metricType(IndexParam.MetricType.L2)
                 .build());
         indexes.add(IndexParam.builder()
                 .fieldName("sparse_vector")
@@ -992,8 +992,8 @@ public class BulkWriterExample {
                 .dimension(DIM)
                 .build());
         schemaV2.addField(AddFieldReq.builder()
-                .fieldName("float16_vector")
-                .dataType(DataType.Float16Vector)
+                .fieldName("int8_vector")
+                .dataType(DataType.Int8Vector)
                 .dimension(DIM)
                 .build());
         schemaV2.addField(AddFieldReq.builder()

+ 3 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

@@ -289,7 +289,8 @@ public abstract class BulkWriter implements AutoCloseable {
                 case FloatVector:
                 case Float16Vector:
                 case BFloat16Vector:
-                case SparseFloatVector: {
+                case SparseFloatVector:
+                case Int8Vector:{
                     Pair<Object, Integer> objectAndSize = verifyVector(obj, field);
                     rowValues.put(fieldName, objectAndSize.getLeft());
                     rowSize += objectAndSize.getRight();
@@ -368,6 +369,7 @@ public abstract class BulkWriter implements AutoCloseable {
             case FloatVector:
                 return Pair.of(vector, ((List<?>) vector).size() * 4);
             case BinaryVector:
+            case Int8Vector:
                 return Pair.of(vector, ((ByteBuffer)vector).limit());
             case Float16Vector:
             case BFloat16Vector:

+ 3 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

@@ -81,8 +81,10 @@ public class ParquetUtils {
                 case BinaryVector:
                 case Float16Vector:
                 case BFloat16Vector:
+                case Int8Vector:
+                    boolean isSigned = (field.getDataType() == io.milvus.v2.common.DataType.Int8Vector);
                     setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
-                            LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false), field, true);
+                            LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, isSigned), field, true);
                     break;
                 case Array:
                     fillArrayType(messageTypeBuilder, field);

+ 1 - 0
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/writer/ParquetFileWriter.java

@@ -144,6 +144,7 @@ public class ParquetFileWriter implements FormatFileWriter {
             case BinaryVector:
             case Float16Vector:
             case BFloat16Vector:
+            case Int8Vector:
                 addBinaryVector(group, paramName, (ByteBuffer) value);
                 break;
             case SparseFloatVector:

+ 13 - 0
sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java

@@ -202,6 +202,11 @@ public class BulkWriterTest {
                 .fieldName("sparse_vector_field")
                 .dataType(DataType.SparseFloatVector)
                 .build());
+        schemaV2.addField(AddFieldReq.builder()
+                .fieldName("int8_vector_field")
+                .dataType(DataType.Int8Vector)
+                .dimension(DIMENSION)
+                .build());
         return schemaV2;
     }
 
@@ -274,6 +279,7 @@ public class BulkWriterTest {
             rowObject.add("float_vector_field", JsonUtils.toJsonTree(utils.generateFloatVector()));
             rowObject.add("binary_vector_field", JsonUtils.toJsonTree(utils.generateBinaryVector().array()));
             rowObject.add("sparse_vector_field", JsonUtils.toJsonTree(utils.generateSparseVector()));
+            rowObject.add("int8_vector_field", JsonUtils.toJsonTree(utils.generateInt8Vector().array()));
 
             rows.add(rowObject);
         }
@@ -368,6 +374,7 @@ public class BulkWriterTest {
                 rowObject.add("float_vector_field", JsonUtils.toJsonTree(utils.generateFloatVector()));
                 rowObject.add("binary_vector_field", JsonUtils.toJsonTree(utils.generateBinaryVector().array()));
                 rowObject.add("sparse_vector_field", JsonUtils.toJsonTree(utils.generateSparseVector()));
+                rowObject.add("int8_vector_field", JsonUtils.toJsonTree(utils.generateInt8Vector().array()));
                 rowObject.add("arr_int32_field", JsonUtils.toJsonTree(GeneratorUtils.generatorInt32Value(2)));
                 rowObject.add("arr_float_field", JsonUtils.toJsonTree(GeneratorUtils.generatorFloatValue(3)));
                 rowObject.add("arr_varchar_field", JsonUtils.toJsonTree(GeneratorUtils.generatorVarcharValue(4, 5)));
@@ -421,6 +428,12 @@ public class BulkWriterTest {
                 // set incorrect type for varchar field, expect throwing an exception
                 rowObject.addProperty("float_field", 2.5);
                 rowObject.addProperty("varchar_field", 2.5);
+//                localBulkWriter.appendRow(rowObject);
+                Assertions.assertThrows(MilvusException.class, ()->localBulkWriter.appendRow(rowObject));
+
+                // set incorrect value type for int8 vector field, expect throwing an exception
+                rowObject.addProperty("varchar_field", "dummy");
+                rowObject.addProperty("int8_vector_field", Boolean.TRUE);
 //                localBulkWriter.appendRow(rowObject);
                 Assertions.assertThrows(MilvusException.class, ()->localBulkWriter.appendRow(rowObject));
             } catch (Exception e) {

+ 4 - 0
sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/TestUtils.java

@@ -58,6 +58,10 @@ public class TestUtils {
 
     }
 
+    public ByteBuffer generateInt8Vector() {
+        return generateBinaryVector(dimension*8);
+    }
+
     public ByteBuffer generateFloat16Vector() {
         List<Float> vector = generateFloatVector();
         return Float16Utils.f32VectorToFp16Buffer(vector);

+ 1 - 0
sdk-core/src/main/java/io/milvus/param/ParamUtils.java

@@ -90,6 +90,7 @@ public class ParamUtils {
         typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be JsonArray of byte[].");
         typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be JsonArray of byte[].");
         typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be JsonObject of Map<Long, Float>.");
+        typeErrMsg.put(DataType.Int8Vector, "Type mismatch for field '%s': Int8Vector vector field's value type must be JsonArray of byte[].");
         return typeErrMsg;
     }