Browse Source

Fix a bug in SearchResultsWrapper.getRowRecords() that returned wrong data for output fields (#1443)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 week ago
parent
commit
210e93f897

+ 12 - 1
examples/src/main/java/io/milvus/v1/CommonUtils.java

@@ -75,6 +75,18 @@ public class CommonUtils {
         return vectors;
     }
 
+    /**
+     * Verifies that two float vectors have the same dimension and that every pair of
+     * components differs by at most 0.001.
+     *
+     * @param vec1 the first vector
+     * @param vec2 the second vector
+     * @throws RuntimeException if the sizes differ, or if any pair of components is not
+     *                          within the tolerance (a NaN paired with a number counts
+     *                          as a mismatch)
+     */
+    public static void compareFloatVectors(List<Float> vec1, List<Float> vec2) {
+        if (vec1.size() != vec2.size()) {
+            throw new RuntimeException(String.format("Vector dimension mismatch: %d vs %d", vec1.size(), vec2.size()));
+        }
+        for (int i = 0; i < vec1.size(); i++) {
+            float left = vec1.get(i);
+            float right = vec2.get(i);
+            // Float.compare treats identical values (including NaN/NaN and same-signed
+            // infinities) as equal; the negated "<=" makes a NaN difference a mismatch
+            // instead of letting it slip through a ">" comparison.
+            boolean sameValue = Float.compare(left, right) == 0;
+            if (!sameValue && !(Math.abs(left - right) <= 0.001f)) {
+                throw new RuntimeException(String.format("Vector value mismatch: %f vs %f at No.%d value",
+                        left, right, i));
+            }
+        }
+    }
+
     /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static ByteBuffer generateBinaryVector(int dimension) {
         Random ran = new Random();
@@ -302,5 +314,4 @@ public class CommonUtils {
         }
         return vectors;
     }
-
 }

+ 73 - 4
examples/src/main/java/io/milvus/v1/JsonFieldExample.java

@@ -26,25 +26,30 @@ import io.milvus.client.MilvusServiceClient;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.grpc.DataType;
 import io.milvus.grpc.QueryResults;
+import io.milvus.grpc.SearchResults;
 import io.milvus.param.*;
 import io.milvus.param.collection.*;
 import io.milvus.param.dml.InsertParam;
 import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.SearchParam;
 import io.milvus.param.index.CreateIndexParam;
 import io.milvus.response.QueryResultsWrapper;
+import io.milvus.response.SearchResultsWrapper;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
 public class JsonFieldExample {
     private static final String COLLECTION_NAME = "java_sdk_example_json_v1";
-    private static final String ID_FIELD = "id";
+    private static final String ID_FIELD = "key";
     private static final String VECTOR_FIELD = "vector";
     private static final String JSON_FIELD = "metadata";
     private static final Integer VECTOR_DIM = 128;
 
     private static void queryWithExpr(MilvusClient client, String expr) {
+        System.out.printf("%n=============================Query with expr: '%s'================================%n", expr);
         R<QueryResults> queryRet = client.query(QueryParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withExpr(expr)
@@ -56,7 +61,6 @@ public class JsonFieldExample {
         for (QueryResultsWrapper.RowRecord record : records) {
             System.out.println(record);
         }
-        System.out.println("=============================================================");
     }
 
     public static void main(String[] args) {
@@ -123,22 +127,28 @@ public class JsonFieldExample {
         System.out.println("Collection created");
 
         // insert rows
+        List<List<Float>> vectors = new ArrayList<>();
+        List<JsonObject> metadatas = new ArrayList<>();
         Gson gson = new Gson();
         for (int i = 0; i < 100; i++) {
             JsonObject row = new JsonObject();
             row.addProperty(ID_FIELD, i);
-            row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            vectors.add(vector);
 
             // Note: for JSON field, always construct a real JsonObject
             // don't use row.addProperty(JSON_FIELD, strContent) since the value is treated as a string, not a JsonObject
             JsonObject metadata = new JsonObject();
-            metadata.addProperty("path", String.format("\\root/abc/path%d", i));
+            metadata.addProperty("path", String.format("\\root/abc/path_%d", i));
             metadata.addProperty("size", i);
             if (i%7 == 0) {
                 metadata.addProperty("special", true);
             }
+
             metadata.add("flags", gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
             row.add(JSON_FIELD, metadata);
+            metadatas.add(metadata);
 //            System.out.println(metadata);
 
             // dynamic fields
@@ -165,6 +175,65 @@ public class JsonFieldExample {
         long rowCount = (long)queryWrapper.getFieldWrapper("count(*)").getFieldData().get(0);
         System.out.printf("%d rows persisted\n", rowCount);
 
+        // search and output JSON field
+        List<List<Float>> searchVectors = new ArrayList<>();
+        List<JsonObject> expectedMetadatas = new ArrayList<>();
+        for (int i = 0; i < 10; i++) {
+            List<Float> targetVector = vectors.get(i);
+            searchVectors.add(targetVector);
+            expectedMetadatas.add(metadatas.get(i));
+        }
+        R<SearchResults> searchRet = client.search(SearchParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withLimit(3L)
+                .withFloatVectors(searchVectors)
+                .withVectorFieldName(VECTOR_FIELD)
+                .addOutField(ID_FIELD)
+                .addOutField(VECTOR_FIELD)
+                .addOutField(JSON_FIELD)
+                .build());
+        CommonUtils.handleResponseStatus(searchRet);
+
+        SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
+        System.out.println("\n=============================Search result with IDScore================================");
+        for (int i = 0; i < 10; i++) {
+            List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(i);
+            System.out.printf("\nThe result of No.%d target vector:\n", i);
+            for (SearchResultsWrapper.IDScore score : scores) {
+                System.out.println(score);
+            }
+            long pk = scores.get(0).getLongID();
+            if (pk != i) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
+            }
+            JsonObject metadata = (JsonObject) scores.get(0).get(JSON_FIELD);
+            if (!metadata.equals(expectedMetadatas.get(i))) {
+                throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
+                        metadata, expectedMetadatas.get(i)));
+            }
+            List<Float> vector = (List<Float>) scores.get(0).get(VECTOR_FIELD);
+            CommonUtils.compareFloatVectors(vector, searchVectors.get(i));
+        }
+        System.out.println("\n=============================Search result with RowRecord================================");
+        for (int i = 0; i < 10; i++) {
+            List<QueryResultsWrapper.RowRecord> records = resultsWrapper.getRowRecords(i);
+            System.out.printf("\nThe result of No.%d target vector:\n", i);
+            for (QueryResultsWrapper.RowRecord record : records) {
+                System.out.println(record);
+            }
+            long pk = (long)records.get(0).get(ID_FIELD);
+            if (pk != i) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
+            }
+            JsonObject metadata = (JsonObject) records.get(0).get(JSON_FIELD);
+            if (!metadata.equals(expectedMetadatas.get(i))) {
+                throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
+                        metadata, expectedMetadatas.get(i)));
+            }
+            List<Float> vector = (List<Float>) records.get(0).get(VECTOR_FIELD);
+            CommonUtils.compareFloatVectors(vector, searchVectors.get(i));
+        }
+
         // query by filtering JSON
         queryWithExpr(client, "exists metadata[\"special\"]");
         queryWithExpr(client, "metadata[\"size\"] < 5");

+ 51 - 4
examples/src/main/java/io/milvus/v2/JsonFieldExample.java

@@ -32,18 +32,23 @@ import io.milvus.v2.service.collection.request.CreateCollectionReq;
 import io.milvus.v2.service.collection.request.DropCollectionReq;
 import io.milvus.v2.service.vector.request.InsertReq;
 import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchReq;
+import io.milvus.v2.service.vector.request.data.BaseVector;
+import io.milvus.v2.service.vector.request.data.FloatVec;
 import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
 
 import java.util.*;
 
 public class JsonFieldExample {
     private static final String COLLECTION_NAME = "java_sdk_example_json_v2";
-    private static final String ID_FIELD = "id";
+    private static final String ID_FIELD = "key";
     private static final String VECTOR_FIELD = "vector";
     private static final String JSON_FIELD = "metadata";
     private static final Integer VECTOR_DIM = 128;
 
     private static void queryWithExpr(MilvusClientV2 client, String expr) {
+        System.out.printf("%n=============================Query with expr: '%s'================================%n", expr);
         QueryResp queryRet = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .filter(expr)
@@ -54,7 +59,6 @@ public class JsonFieldExample {
         for (QueryResp.QueryResult record : records) {
             System.out.println(record.getEntity());
         }
-        System.out.println("=============================================================");
     }
 
     public static void main(String[] args) {
@@ -104,22 +108,27 @@ public class JsonFieldExample {
         System.out.println("Collection created");
 
         // Insert rows
+        List<List<Float>> vectors = new ArrayList<>();
+        List<JsonObject> metadatas = new ArrayList<>();
         Gson gson = new Gson();
         for (int i = 0; i < 100; i++) {
             JsonObject row = new JsonObject();
             row.addProperty(ID_FIELD, i);
-            row.add(VECTOR_FIELD, gson.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
+            List<Float> vector = CommonUtils.generateFloatVector(VECTOR_DIM);
+            row.add(VECTOR_FIELD, gson.toJsonTree(vector));
+            vectors.add(vector);
 
             // Note: for JSON field, always construct a real JsonObject
             // don't use row.addProperty(JSON_FIELD, strContent) since the value is treated as a string, not a JsonObject
             JsonObject metadata = new JsonObject();
-            metadata.addProperty("path", String.format("\\root/abc/path%d", i));
+            metadata.addProperty("path", String.format("\\root/abc/path_%d", i));
             metadata.addProperty("size", i);
             if (i%7 == 0) {
                 metadata.addProperty("special", true);
             }
             metadata.add("flags", gson.toJsonTree(Arrays.asList(i, i + 1, i + 2)));
             row.add(JSON_FIELD, metadata);
+            metadatas.add(metadata);
 //            System.out.println(metadata);
 
             // dynamic fields
@@ -144,6 +153,44 @@ public class JsonFieldExample {
                 .build());
         System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
 
+        // Search and output JSON field
+        List<BaseVector> searchVectors = new ArrayList<>();
+        List<JsonObject> expectedMetadatas = new ArrayList<>();
+        for (int i = 0; i < 10; i++) {
+            List<Float> targetVector = vectors.get(i);
+            searchVectors.add(new FloatVec(targetVector));
+            expectedMetadatas.add(metadatas.get(i));
+        }
+        SearchResp searchRet = client.search(SearchReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(searchVectors)
+                .limit(3L)
+                .annsField(VECTOR_FIELD)
+                .outputFields(Arrays.asList(ID_FIELD, VECTOR_FIELD, JSON_FIELD))
+                .build());
+
+        System.out.println("\n=============================Search result================================");
+        List<List<SearchResp.SearchResult>> searchResults = searchRet.getSearchResults();
+        for (int i = 0; i < 10; i++) {
+            List<SearchResp.SearchResult> results = searchResults.get(i);
+            System.out.printf("\nThe result of No.%d target vector:\n", i);
+            for (SearchResp.SearchResult result : results) {
+                System.out.println(result);
+            }
+
+            long pk = (long)results.get(0).getId();
+            if (pk != i) {
+                throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d", pk, i));
+            }
+            JsonObject metadata = (JsonObject) results.get(0).getEntity().get(JSON_FIELD);
+            if (!metadata.equals(expectedMetadatas.get(i))) {
+                throw new RuntimeException(String.format("The top1 metadata %s is not equal to target metadata %s",
+                        metadata, expectedMetadatas.get(i)));
+            }
+            List<Float> vector = (List<Float>) results.get(0).getEntity().get(VECTOR_FIELD);
+            CommonUtils.compareFloatVectors(vector, (List<Float>)searchVectors.get(i).getData());
+        }
+
         // Query by filtering JSON
         queryWithExpr(client, "exists metadata[\"special\"]");
         queryWithExpr(client, "metadata[\"size\"] < 5");

+ 8 - 6
sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -40,8 +40,11 @@ import java.util.Map;
 public class SearchResultsWrapper extends RowRecordWrapper {
     private final SearchResultData results;
 
+    private String primaryKey = "id";
+
     public SearchResultsWrapper(@NonNull SearchResultData results) {
         this.results = results;
+        this.primaryKey = results.getPrimaryFieldName();
     }
 
     /**
@@ -86,13 +89,13 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             IDScore score = idScore.get(i);
             QueryResultsWrapper.RowRecord record = new QueryResultsWrapper.RowRecord();
             if (score.getStrID().isEmpty()) {
-                record.put("id", score.getLongID());
+                record.put(primaryKey, score.getLongID());
             } else {
-                record.put("id", score.getStrID());
+                record.put(primaryKey, score.getStrID());
             }
 
             record.put("score", score.getScore()); // use score instead
-            buildRowRecord(record, i);
+            buildRowRecord(record, indexOfTarget*topK + (long)i);
             records.add(record);
         }
         return records;
@@ -162,7 +165,6 @@ public class SearchResultsWrapper extends RowRecordWrapper {
 
         // set id and score
         IDs ids = results.getIds();
-        String pkName = results.getPrimaryFieldName();
         if (ids.hasIntId()) {
             LongArray longIDs = ids.getIntId();
             if (offset + k > longIDs.getDataCount()) {
@@ -170,7 +172,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScores.add(new IDScore(pkName, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
+                idScores.add(new IDScore(primaryKey, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
             }
         } else if (ids.hasStrId()) {
             StringArray strIDs = ids.getStrId();
@@ -179,7 +181,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScores.add(new IDScore(pkName, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
+                idScores.add(new IDScore(primaryKey, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
             }
         } else {
             // in v2.3.3, return an empty list instead of throwing exception

+ 11 - 0
sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java

@@ -678,6 +678,17 @@ class MilvusClientDockerTest {
             for (int k = 0; k < outputVec.size(); k++) {
                 Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
             }
+
+            // verify the old way
+            List<QueryResultsWrapper.RowRecord> records = results.getRowRecords(i);
+            obj = records.get(0).get(DataType.FloatVector.name());
+            outputVec = (List<Float>)obj;
+            Assertions.assertEquals(targetVectors.get(i).size(), outputVec.size());
+            for (int k = 0; k < outputVec.size(); k++) {
+                Assertions.assertEquals(targetVectors.get(i).get(k), outputVec.get(k));
+            }
+            double d = (double)records.get(0).get(DataType.Double.name());
+            Assertions.assertEquals(d, compareWeights.get(i));
         }
 
         List<?> fieldData = results.getFieldData(DataType.Double.name(), 0);