Browse Source

Fix binary vector dimension check to use ByteBuffer.limit() instead of position() (#1152)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 8 months ago
parent
commit
df3ce5c686

+ 1 - 1
src/main/java/io/milvus/param/ParamUtils.java

@@ -154,7 +154,7 @@ public class ParamUtils {
 
                     // check dimension
                     ByteBuffer v = (ByteBuffer)value;
-                    int real_dim = calculateBinVectorDim(dataType, v.position());
+                    int real_dim = calculateBinVectorDim(dataType, v.limit());
                     if (real_dim != dim) {
                         String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
                         throw new ParamException(String.format(msg, fieldSchema.getName(), i, real_dim, dim));

+ 2 - 2
src/main/java/io/milvus/param/QueryNodeSingleSearch.java

@@ -172,10 +172,10 @@ public class QueryNodeSingleSearch {
             } else if (vectors.get(0) instanceof ByteBuffer) {
                 // binary vectors
                 ByteBuffer first = (ByteBuffer) vectors.get(0);
-                int dim = first.position();
+                int dim = first.limit();
                 for (int i = 1; i < vectors.size(); ++i) {
                     ByteBuffer temp = (ByteBuffer) vectors.get(i);
-                    if (dim != temp.position()) {
+                    if (dim != temp.limit()) {
                         throw new ParamException("Target vector dimension must be equal");
                     }
                 }

+ 2 - 2
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -439,10 +439,10 @@ public class SearchParam {
             // BinaryVector/Float16Vector/BFloatVector
             // TODO: here only check the first element, potential risk
             ByteBuffer first = (ByteBuffer) vectors.get(0);
-            int len = first.position();
+            int len = first.limit();
             for (int i = 1; i < vectors.size(); ++i) {
                 ByteBuffer temp = (ByteBuffer) vectors.get(i);
-                if (len != temp.position()) {
+                if (len != temp.limit()) {
                     throw new ParamException("Target vector dimension must be equal");
                 }
             }

+ 23 - 3
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -895,6 +895,21 @@ class MilvusClientDockerTest {
         List<Long> ids2 = insertResultWrapper.getLongIDs(); // get returned IDs(generated by server-side)
         Assertions.assertEquals(rowCount, ids2.size());
 
+        // insert test vector, position() is zero with ByteBuffer.wrap()
+        byte[] byteArray = new byte[DIMENSION/8];
+        for (int i = 0; i < byteArray.length; i++) {
+            byteArray[i] = (byte) ((i%3 == 0) ? 255 : 0);
+        }
+        ByteBuffer testBuffer = ByteBuffer.wrap(byteArray);
+        List<InsertParam.Field> testData =
+                Collections.singletonList(new InsertParam.Field(DataType.BinaryVector.name(), Collections.singletonList(testBuffer)));
+        R<MutationResult> insertR3 = client.insert(InsertParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withFields(testData)
+                .build());
+        insertResultWrapper = new MutationResultWrapper(insertR3.getData());
+        Long testID = insertResultWrapper.getLongIDs().get(0);
+
         // get collection statistics
         R<GetCollectionStatisticsResponse> statR = client.getCollectionStatistics(GetCollectionStatisticsParam
                 .newBuilder()
@@ -905,7 +920,7 @@ class MilvusClientDockerTest {
 
         GetCollStatResponseWrapper stat = new GetCollStatResponseWrapper(statR.getData());
         System.out.println("Collection row count: " + stat.getRowCount());
-        Assertions.assertEquals(2*rowCount, stat.getRowCount());
+        Assertions.assertEquals(2*rowCount+1, stat.getRowCount());
 
         // check index
         while(true) {
@@ -927,8 +942,8 @@ class MilvusClientDockerTest {
             Assertions.assertEquals(DataType.BinaryVector.name(), indexDesc.getFieldName());
             Assertions.assertEquals(IndexType.BIN_IVF_FLAT, indexDesc.getIndexType());
             Assertions.assertEquals(MetricType.JACCARD, indexDesc.getMetricType());
-            Assertions.assertEquals(2*rowCount, indexDesc.getTotalRows());
-            Assertions.assertEquals(2*rowCount, indexDesc.getIndexedRows());
+            Assertions.assertEquals(2*rowCount+1, indexDesc.getTotalRows());
+            Assertions.assertEquals(2*rowCount+1, indexDesc.getIndexedRows());
             Assertions.assertEquals(0L, indexDesc.getPendingIndexRows());
             Assertions.assertTrue(indexDesc.getIndexFailedReason().isEmpty());
             System.out.println("Index description: " + indexDesc);
@@ -1026,6 +1041,8 @@ class MilvusClientDockerTest {
             targetVectorIDs.add(ids1.get(i));
             targetVectors.add((ByteBuffer) columnsData.get(0).getValues().get(i));
         }
+        targetVectors.add(testBuffer);
+        targetVectorIDs.add(testID);
 
         int topK = 5;
         SearchParam searchParam = SearchParam.newBuilder()
@@ -1035,6 +1052,7 @@ class MilvusClientDockerTest {
                 .withBinaryVectors(targetVectors)
                 .withVectorFieldName(DataType.BinaryVector.name())
                 .withParams("{\"nprobe\":8}")
+                .withOutFields(Collections.singletonList(DataType.BinaryVector.name()))
                 .build();
 
         R<SearchResults> searchR = client.search(searchParam);
@@ -1048,6 +1066,8 @@ class MilvusClientDockerTest {
             System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
             System.out.println(scores);
             Assertions.assertEquals(targetVectorIDs.get(i), scores.get(0).getLongID());
+            ByteBuffer buf = (ByteBuffer) scores.get(0).get(DataType.BinaryVector.name());
+            Assertions.assertArrayEquals(targetVectors.get(i).array(), buf.array());
         }
 
         // drop collection

+ 6 - 0
src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -637,10 +637,12 @@ class MilvusClientV2DockerTest {
         int topk = 10;
         List<Long> targetIDs = new ArrayList<>();
         List<BaseVector> targetVectors = new ArrayList<>();
+        List<ByteBuffer> targetOriginVectors = new ArrayList<>();
         for (int i = 0; i < nq; i++) {
             JsonObject row = data.get(RANDOM.nextInt((int)count));
             targetIDs.add(row.get("id").getAsLong());
             byte[] vector = JsonUtils.fromJson(row.get(vectorFieldName), new TypeToken<byte[]>() {}.getType());
+            targetOriginVectors.add(ByteBuffer.wrap(vector));
             targetVectors.add(new BinaryVec(vector));
         }
         SearchResp searchResp = client.search(SearchReq.builder()
@@ -648,6 +650,7 @@ class MilvusClientV2DockerTest {
                 .annsField(vectorFieldName)
                 .data(targetVectors)
                 .topK(10)
+                .outputFields(Collections.singletonList(vectorFieldName))
                 .build());
         List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
         Assertions.assertEquals(nq, searchResults.size());
@@ -655,6 +658,9 @@ class MilvusClientV2DockerTest {
             List<SearchResp.SearchResult> results = searchResults.get(i);
             Assertions.assertEquals(topk, results.size());
             Assertions.assertEquals(targetIDs.get(i), results.get(0).getId());
+
+            ByteBuffer buf = (ByteBuffer) results.get(0).getEntity().get(vectorFieldName);
+            Assertions.assertArrayEquals(targetOriginVectors.get(i).array(), buf.array());
         }
 
         client.dropCollection(DropCollectionReq.builder().collectionName(randomCollectionName).build());