Browse Source

Verify retrieve (#914)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 year ago
parent
commit
45d7fa2c67

+ 2 - 2
examples/main/java/io/milvus/v1/CommonUtils.java

@@ -96,10 +96,10 @@ public class CommonUtils {
         for (int i = 0; i < dimension; ++i) {
         for (int i = 0; i < dimension; ++i) {
             ByteDataBuffer bf;
             ByteDataBuffer bf;
             if (bfloat16) {
             if (bfloat16) {
-                TFloat16 tt = TFloat16.scalarOf((float)ran.nextInt(dimension));
+                TBfloat16 tt = TBfloat16.scalarOf((float)ran.nextInt(dimension));
                 bf = tt.asRawTensor().data();
                 bf = tt.asRawTensor().data();
             } else {
             } else {
-                TBfloat16 tt = TBfloat16.scalarOf((float)ran.nextInt(dimension));
+                TFloat16 tt = TFloat16.scalarOf((float)ran.nextInt(dimension));
                 bf = tt.asRawTensor().data();
                 bf = tt.asRawTensor().data();
             }
             }
             vector.put(bf.getByte(0));
             vector.put(bf.getByte(0));

+ 90 - 0
examples/main/java/io/milvus/v1/Float16VectorExample.java

@@ -32,6 +32,13 @@ import io.milvus.response.*;
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.*;
 
 
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.buffer.ByteDataBuffer;
+import org.tensorflow.ndarray.buffer.DataBuffers;
+import org.tensorflow.types.TBfloat16;
+import org.tensorflow.types.TFloat16;
+
 
 
 public class Float16VectorExample {
 public class Float16VectorExample {
     private static final String COLLECTION_NAME = "java_sdk_example_float16";
     private static final String COLLECTION_NAME = "java_sdk_example_float16";
@@ -202,6 +209,44 @@ public class Float16VectorExample {
         }
         }
         System.out.println("Query result is correct");
         System.out.println("Query result is correct");
 
 
+        // insert a single row
+        JsonObject row = new JsonObject();
+        row.addProperty(ID_FIELD, 9999999);
+        List<Float> newVector = CommonUtils.generateFloatVector(VECTOR_DIM);
+        ByteBuffer vector16Buf = encodeTF(newVector, bfloat16);
+        row.add(VECTOR_FIELD, gson.toJsonTree(vector16Buf.array()));
+        insertR = milvusClient.insert(InsertParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withRows(Collections.singletonList(row))
+                .build());
+        CommonUtils.handleResponseStatus(insertR);
+
+        // retrieve the single row
+        queryR = milvusClient.query(QueryParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withExpr("id == 9999999")
+                .addOutField(VECTOR_FIELD)
+                .build());
+        CommonUtils.handleResponseStatus(queryR);
+        queryWrapper = new QueryResultsWrapper(queryR.getData());
+        field = queryWrapper.getFieldWrapper(VECTOR_FIELD);
+        r = field.getFieldData();
+        if (r.isEmpty()) {
+            throw new RuntimeException("The retrieve result is empty");
+        } else {
+            ByteBuffer outBuf = (ByteBuffer) r.get(0);
+            List<Float> outVector = decodeTF(outBuf, bfloat16);
+            if (outVector.size() != newVector.size()) {
+                throw new RuntimeException("The retrieve result is incorrect");
+            }
+            for (int i = 0; i < outVector.size(); i++) {
+                if (!isFloat16Eauql(outVector.get(i), newVector.get(i), bfloat16)) {
+                    throw new RuntimeException("The retrieve result is incorrect");
+                }
+            }
+        }
+        System.out.println("Retrieve result is correct");
+
         // drop the collection if you don't need the collection anymore
         // drop the collection if you don't need the collection anymore
         milvusClient.dropCollection(DropCollectionParam.newBuilder()
         milvusClient.dropCollection(DropCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withCollectionName(COLLECTION_NAME)
@@ -211,6 +256,51 @@ public class Float16VectorExample {
         milvusClient.close();
         milvusClient.close();
     }
     }
 
 
+    private static ByteBuffer encodeTF(List<Float> vector, boolean bfloat16) {
+        ByteBuffer buf = ByteBuffer.allocate(vector.size() * 2);
+        for (Float value : vector) {
+            ByteDataBuffer bf;
+            if (bfloat16) {
+                TBfloat16 tt = TBfloat16.scalarOf(value);
+                bf = tt.asRawTensor().data();
+            } else {
+                TFloat16 tt = TFloat16.scalarOf(value);
+                bf = tt.asRawTensor().data();
+            }
+            buf.put(bf.getByte(0));
+            buf.put(bf.getByte(1));
+        }
+        return buf;
+    }
+
+    private static List<Float> decodeTF(ByteBuffer buf, boolean bfloat16) {
+        int dim = buf.limit()/2;
+        ByteDataBuffer bf = DataBuffers.of(buf.array());
+        List<Float> vec = new ArrayList<>();
+        if (bfloat16) {
+            TBfloat16 tf = Tensor.of(TBfloat16.class, Shape.of(dim), bf);
+            for (long k = 0; k < tf.size(); k++) {
+                vec.add(tf.getFloat(k));
+            }
+        } else {
+            TFloat16 tf = Tensor.of(TFloat16.class, Shape.of(dim), bf);
+            for (long k = 0; k < tf.size(); k++) {
+                vec.add(tf.getFloat(k));
+            }
+        }
+
+        return vec;
+    }
+
+    private static boolean isFloat16Eauql(Float a, Float b, boolean bfloat16) {
+        if (bfloat16) {
+            return Math.abs(a - b) <= 0.01f;
+        } else {
+            return Math.abs(a - b) <= 0.001f;
+        }
+    }
+
+
     public static void main(String[] args) {
     public static void main(String[] args) {
         testFloat16(true);
         testFloat16(true);
         testFloat16(false);
         testFloat16(false);

+ 77 - 4
src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -354,6 +354,27 @@ class MilvusClientDockerTest {
         ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
         ShowPartResponseWrapper infoPart = new ShowPartResponseWrapper(showPartR.getData());
         System.out.println("Partition info: " + infoPart.toString());
         System.out.println("Partition info: " + infoPart.toString());
 
 
+        // query
+        Long fetchID = ids.get(0);
+        List<Float> fetchVector = vectors.get(0);
+        R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(String.format("%s == %d", field1Name, fetchID))
+                .addOutField(field2Name)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
+        QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
+        FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
+        Assertions.assertEquals(1L, fetchField.getRowCount());
+        List<?> fetchObj = fetchField.getFieldData();
+        Assertions.assertEquals(1, fetchObj.size());
+        Assertions.assertInstanceOf(List.class, fetchObj.get(0));
+        List<Float> fetchResult = (List<Float>) fetchObj.get(0);
+        Assertions.assertEquals(fetchVector.size(), fetchResult.size());
+        for (int i = 0; i < fetchResult.size(); i++) {
+            Assertions.assertEquals(fetchVector.get(i), fetchResult.get(i));
+        }
+
         // query vectors to verify
         // query vectors to verify
         List<Long> queryIDs = new ArrayList<>();
         List<Long> queryIDs = new ArrayList<>();
         List<Double> compareWeights = new ArrayList<>();
         List<Double> compareWeights = new ArrayList<>();
@@ -450,6 +471,7 @@ class MilvusClientDockerTest {
                 .withVectorFieldName(field2Name)
                 .withVectorFieldName(field2Name)
                 .withParams("{\"ef\":64}")
                 .withParams("{\"ef\":64}")
                 .addOutField(field4Name)
                 .addOutField(field4Name)
+                .addOutField(field2Name)
                 .build();
                 .build();
 
 
         R<SearchResults> searchR = client.search(searchParam);
         R<SearchResults> searchR = client.search(searchParam);
@@ -462,7 +484,15 @@ class MilvusClientDockerTest {
             List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
             List<SearchResultsWrapper.IDScore> scores = results.getIDScore(i);
             System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
             System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):");
             System.out.println(scores);
             System.out.println(scores);
-            Assertions.assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID());
+            Assertions.assertEquals(targetVectorIDs.get(i), scores.get(0).getLongID());
+
+            Object obj = scores.get(0).get(field2Name);
+            Assertions.assertInstanceOf(List.class, obj);
+            List<Float> outputVec = (List<Float>)obj;
+            Assertions.assertEquals(targetVectors.get(i).size(), outputVec.size());
+            for (int k = 0; k < outputVec.size(); k++) {
+                Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
+            }
         }
         }
 
 
         List<?> fieldData = results.getFieldData(field4Name, 0);
         List<?> fieldData = results.getFieldData(field4Name, 0);
@@ -597,6 +627,24 @@ class MilvusClientDockerTest {
                 .build());
                 .build());
         Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
         Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
 
+        // query
+        Long fetchID = ids1.get(0);
+        ByteBuffer fetchVector = vectors.get(0);
+        R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(String.format("%s == %d", field1Name, fetchID))
+                .addOutField(field2Name)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
+        QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
+        FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
+        Assertions.assertEquals(1L, fetchField.getRowCount());
+        List<?> fetchObj = fetchField.getFieldData();
+        Assertions.assertEquals(1, fetchObj.size());
+        Assertions.assertInstanceOf(ByteBuffer.class, fetchObj.get(0));
+        ByteBuffer fetchBuffer = (ByteBuffer) fetchObj.get(0);
+        Assertions.assertArrayEquals(fetchVector.array(), fetchBuffer.array());
+
         // search with BIN_FLAT index
         // search with BIN_FLAT index
         int searchTarget = 99;
         int searchTarget = 99;
         List<ByteBuffer> oneVector = new ArrayList<>();
         List<ByteBuffer> oneVector = new ArrayList<>();
@@ -623,9 +671,7 @@ class MilvusClientDockerTest {
         List<?> items = oneResult.getFieldData(field2Name, 0);
         List<?> items = oneResult.getFieldData(field2Name, 0);
         Assertions.assertEquals(items.size(), 5);
         Assertions.assertEquals(items.size(), 5);
         ByteBuffer firstItem = (ByteBuffer) items.get(0);
         ByteBuffer firstItem = (ByteBuffer) items.get(0);
-        for (int i = 0; i < firstItem.limit(); ++i) {
-            Assertions.assertEquals(firstItem.get(i), vectors.get(searchTarget).get(i));
-        }
+        Assertions.assertArrayEquals(vectors.get(searchTarget).array(), firstItem.array());
 
 
         // release collection
         // release collection
         ReleaseCollectionParam releaseCollectionParam = ReleaseCollectionParam.newBuilder()
         ReleaseCollectionParam releaseCollectionParam = ReleaseCollectionParam.newBuilder()
@@ -773,6 +819,28 @@ class MilvusClientDockerTest {
                 .build());
                 .build());
         Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
         Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue());
 
 
+        // query
+        Long fetchID = ids.get(0);
+        SortedMap<Long, Float> fetchVector = vectors.get(0);
+        R<QueryResults> fetchR = client.query(QueryParam.newBuilder()
+                .withCollectionName(randomCollectionName)
+                .withExpr(String.format("%s == %d", field1Name, fetchID))
+                .addOutField(field2Name)
+                .build());
+        Assertions.assertEquals(R.Status.Success.getCode(), fetchR.getStatus().intValue());
+        QueryResultsWrapper fetchWrapper = new QueryResultsWrapper(fetchR.getData());
+        FieldDataWrapper fetchField = fetchWrapper.getFieldWrapper(field2Name);
+        Assertions.assertEquals(1L, fetchField.getRowCount());
+        List<?> fetchObj = fetchField.getFieldData();
+        Assertions.assertEquals(1, fetchObj.size());
+        Assertions.assertInstanceOf(SortedMap.class, fetchObj.get(0));
+        SortedMap<Long, Float> fetchSparse = (SortedMap<Long, Float>) fetchObj.get(0);
+        Assertions.assertEquals(fetchVector.size(), fetchSparse.size());
+        for (Long key : fetchVector.keySet()) {
+            Assertions.assertTrue(fetchSparse.containsKey(key));
+            Assertions.assertEquals(fetchVector.get(key), fetchSparse.get(key));
+        }
+
         // pick some vectors to search with index
         // pick some vectors to search with index
         int nq = 5;
         int nq = 5;
         List<Long> targetVectorIDs = new ArrayList<>();
         List<Long> targetVectorIDs = new ArrayList<>();
@@ -813,6 +881,11 @@ class MilvusClientDockerTest {
             Object v = scores.get(0).get(field2Name);
             Object v = scores.get(0).get(field2Name);
             SortedMap<Long, Float> sparse = (SortedMap<Long, Float>)v;
             SortedMap<Long, Float> sparse = (SortedMap<Long, Float>)v;
             Assertions.assertTrue(sparse.equals(targetVectors.get(i)));
             Assertions.assertTrue(sparse.equals(targetVectors.get(i)));
+            Assertions.assertEquals(targetVectors.get(i).size(), sparse.size());
+            for (Long key : sparse.keySet()) {
+                Assertions.assertTrue(targetVectors.get(i).containsKey(key));
+                Assertions.assertEquals(sparse.get(key), targetVectors.get(i).get(key));
+            }
         }
         }
 
 
         // drop collection
         // drop collection