Quellcode durchsuchen

Fix a bug of binary vectors (#609)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot vor 1 Jahr
Ursprung
Commit
05d3f4df7f

+ 1 - 1
README.md

@@ -37,7 +37,7 @@ You can use **Apache Maven** or **Gradle**/**Grails** to download the SDK.
    - Gradle/Grails
 
         ```gradle
-        compile 'io.milvus:milvus-sdk-java:2.2.12'
+        implementation 'io.milvus:milvus-sdk-java:2.2.12'
         ```
 
 ### Examples

+ 7 - 6
src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -74,11 +74,11 @@ public class FieldDataWrapper {
             case BinaryVector: {
                 int dim = getDim();
                 ByteString data = fieldData.getVectors().getBinaryVector();
-                if (data.size() % dim != 0) {
+                if ((data.size()*8) % dim != 0) {
                     throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
                 }
 
-                return data.size()/dim;
+                return (data.size()*8)/dim;
             }
             case Int64:
                 return fieldData.getScalars().getLongData().getDataList().size();
@@ -138,15 +138,16 @@ public class FieldDataWrapper {
             case BinaryVector: {
                 int dim = getDim();
                 ByteString data = fieldData.getVectors().getBinaryVector();
-                if (data.size() % dim != 0) {
+                if ((data.size()*8) % dim != 0) {
                     throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
                 }
 
                 List<ByteBuffer> packData = new ArrayList<>();
-                int count = data.size() / dim;
+                int bytePerVec = dim/8;
+                int count = data.size()/bytePerVec;
                 for (int i = 0; i < count; ++i) {
-                    ByteBuffer bf = ByteBuffer.allocate(dim);
-                    bf.put(data.substring(i * dim, (i + 1) * dim).toByteArray());
+                    ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
+                    bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
                     packData.add(bf);
                 }
                 return packData;

+ 16 - 5
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -627,8 +627,9 @@ class MilvusClientDockerTest {
         CreateIndexParam indexParam2 = CreateIndexParam.newBuilder()
                 .withCollectionName(randomCollectionName)
                 .withFieldName(field2Name)
-                .withIndexType(IndexType.BIN_FLAT)
-                .withMetricType(MetricType.SUPERSTRUCTURE)
+                .withIndexType(IndexType.BIN_IVF_FLAT)
+                .withExtraParam("{\"nlist\":64}")
+                .withMetricType(MetricType.JACCARD)
                 .withSyncMode(Boolean.TRUE)
                 .withSyncWaitingInterval(500L)
                 .withSyncWaitingTimeout(30L)
@@ -644,15 +645,17 @@ class MilvusClientDockerTest {
         assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
         // search with the BIN_IVF_FLAT index and JACCARD metric
+        int searchTarget = 99;
         List<ByteBuffer> oneVector = new ArrayList<>();
-        oneVector.add(vectors.get(0));
+        oneVector.add(vectors.get(searchTarget));
 
         SearchParam searchOneParam = SearchParam.newBuilder()
                 .withCollectionName(randomCollectionName)
-                .withMetricType(MetricType.SUPERSTRUCTURE)
+                .withMetricType(MetricType.JACCARD)
                 .withTopK(5)
                 .withVectors(oneVector)
                 .withVectorFieldName(field2Name)
+                .addOutField(field2Name)
                 .build();
 
         R<SearchResults> searchOne = client.search(searchOneParam);
@@ -660,9 +663,17 @@ class MilvusClientDockerTest {
 
         SearchResultsWrapper oneResult = new SearchResultsWrapper(searchOne.getData().getResults());
         List<SearchResultsWrapper.IDScore> oneScores = oneResult.getIDScore(0);
-        System.out.println("The result of " + ids.get(0) + " with SUPERSTRUCTURE metric:");
+        System.out.println("The search result of id " + ids.get(searchTarget) + " with SUPERSTRUCTURE metric:");
         System.out.println(oneScores);
 
+        // verify the returned vectors: the top-1 result must equal the target vector
+        List<?> items = oneResult.getFieldData(field2Name, 0);
+        Assertions.assertEquals(items.size(), 5);
+        ByteBuffer firstItem = (ByteBuffer) items.get(0);
+        for (int i = 0; i < firstItem.limit(); ++i) {
+            Assertions.assertEquals(firstItem.get(i), vectors.get(searchTarget).get(i));
+        }
+
         // release collection
         ReleaseCollectionParam releaseCollectionParam = ReleaseCollectionParam.newBuilder()
                 .withCollectionName(randomCollectionName).build();

+ 12 - 6
src/test/java/io/milvus/client/MilvusServiceClientTest.java

@@ -2747,7 +2747,9 @@ class MilvusServiceClientTest {
 
         // for binary vector
         dim = 16;
-        byte[] binary = new byte[(int) dim * 2];
+        int bytesPerVec = (int) (dim/8);
+        int count = 2;
+        byte[] binary = new byte[bytesPerVec * count];
         for (int i = 0; i < binary.length; ++i) {
             binary[i] = (byte) i;
         }
@@ -2763,13 +2765,17 @@ class MilvusServiceClientTest {
 
         wrapper = new FieldDataWrapper(fieldData);
         assertEquals(dim, wrapper.getDim());
-        assertEquals(binary.length / dim, wrapper.getRowCount());
+        assertEquals(count, wrapper.getRowCount());
 
         List<?> binaryData = wrapper.getFieldData();
-        assertEquals(binary.length / dim, binaryData.size());
-        for (Object obj : binaryData) {
-            ByteBuffer vec = (ByteBuffer) obj;
-            assertEquals(dim, vec.position());
+        assertEquals(count, binaryData.size());
+        for(int i = 0; i < binaryData.size(); i++) {
+            ByteBuffer vec = (ByteBuffer) binaryData.get(i);
+            assertEquals(bytesPerVec, vec.limit());
+
+            for(int j = 0; j < bytesPerVec; j++) {
+                assertEquals(binary[i*bytesPerVec + j], vec.get(j));
+            }
         }
 
         // for scalar field