Переглянути джерело

Reformat SearchResult/IDScore print content to show primary key (#1420)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 місяць тому
батько
коміт
cadbaeb9a4

+ 7 - 9
examples/src/main/java/io/milvus/v1/BinaryVectorExample.java

@@ -38,10 +38,10 @@ import java.util.*;
 
 public class BinaryVectorExample {
     private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v1";
-    private static final String ID_FIELD = "id";
+    private static final String ID_FIELD = "pk";
     private static final String VECTOR_FIELD = "vector";
 
-    private static final Integer VECTOR_DIM = 512;
+    private static final Integer VECTOR_DIM = 128;
     
 
     public static void main(String[] args) {
@@ -152,6 +152,8 @@ public class BinaryVectorExample {
             Random ran = new Random();
             int k = ran.nextInt(rowCount);
             ByteBuffer targetVector = vectors.get(k);
+            System.out.printf("\nANN search for vector ID=%d:\n", k);
+            CommonUtils.printBinaryVector(targetVector);
             R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.HAMMING)
@@ -169,13 +171,9 @@ public class BinaryVectorExample {
             List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
             System.out.printf("The result of No.%d target vector:\n", i);
             for (SearchResultsWrapper.IDScore score : scores) {
-                System.out.printf("ID: %d, Score: %f, Vector: ", score.getLongID(), score.getScore());
+                System.out.println(score);
                 ByteBuffer vector = (ByteBuffer)score.get(VECTOR_FIELD);
-                vector.rewind();
-                while (vector.hasRemaining()) {
-                    System.out.print(Integer.toBinaryString(vector.get()));
-                }
-                System.out.println();
+                CommonUtils.printBinaryVector(vector);
             }
             if (scores.get(0).getLongID() != k) {
                 throw new RuntimeException(String.format("The top1 ID %d is not equal to target vector's ID %d",
@@ -188,7 +186,7 @@ public class BinaryVectorExample {
         int n = 99;
         R<QueryResults> queryR = milvusClient.query(QueryParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
-                .withExpr(String.format("id == %d", n))
+                .withExpr(String.format("%s == %d", ID_FIELD, n))
                 .addOutField(VECTOR_FIELD)
                 .build());
         CommonUtils.handleResponseStatus(queryR);

+ 29 - 2
examples/src/main/java/io/milvus/v1/CommonUtils.java

@@ -96,6 +96,15 @@ public class CommonUtils {
         return vectors;
     }
 
+    public static void printBinaryVector(ByteBuffer vector) {
+        vector.rewind();
+        while (vector.hasRemaining()) {
+            String byteStr = String.format("%8s", Integer.toBinaryString(vector.get())).replace(' ', '0');
+            System.out.print(byteStr);
+        }
+        System.out.println();
+    }
+
     /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static TBfloat16 genTensorflowBF16Vector(int dimension) {
         Random ran = new Random();
@@ -135,7 +144,7 @@ public class CommonUtils {
         return buffers;
     }
 
-    public static TBfloat16 decodeTensorBF16Vector(ByteBuffer buf) {
+    public static TBfloat16 decodeBF16VectorToTensor(ByteBuffer buf) {
         if (buf.limit()%2 != 0) {
             return null;
         }
@@ -144,6 +153,15 @@ public class CommonUtils {
         return Tensor.of(TBfloat16.class, Shape.of(dim), bf);
     }
 
+    public static List<Float> decodeBF16VectorToFloat(ByteBuffer buf) {
+        List<Float> vector = new ArrayList<>();
+        TBfloat16 tf = decodeBF16VectorToTensor(buf);
+        for (long i = 0; i < tf.size(); i++) {
+            vector.add(tf.getFloat(i));
+        }
+        return vector;
+    }
+
 
     public static TFloat16 genTensorflowFP16Vector(int dimension) {
         Random ran = new Random();
@@ -183,7 +201,7 @@ public class CommonUtils {
         return buffers;
     }
 
-    public static TFloat16 decodeTensorFP16Vector(ByteBuffer buf) {
+    public static TFloat16 decodeFP16VectorToTensor(ByteBuffer buf) {
         if (buf.limit()%2 != 0) {
             return null;
         }
@@ -192,6 +210,15 @@ public class CommonUtils {
         return Tensor.of(TFloat16.class, Shape.of(dim), bf);
     }
 
+    public static List<Float> decodeFP16VectorToFloat(ByteBuffer buf) {
+        List<Float> vector = new ArrayList<>();
+        TFloat16 tf = decodeFP16VectorToTensor(buf);
+        for (long i = 0; i < tf.size(); i++) {
+            vector.add(tf.getFloat(i));
+        }
+        return vector;
+    }
+
     /////////////////////////////////////////////////////////////////////////////////////////////////////
     public static ByteBuffer encodeFloat16Vector(List<Float> originVector, boolean bfloat16) {
         if (bfloat16) {

+ 7 - 13
examples/src/main/java/io/milvus/v1/Float16VectorExample.java

@@ -201,9 +201,6 @@ public class Float16VectorExample {
             SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(searchRet.getData().getResults());
             List<SearchResultsWrapper.IDScore> scores = resultsWrapper.getIDScore(0);
             System.out.printf("The result of No.%d target vector:\n", i);
-            for (SearchResultsWrapper.IDScore score : scores) {
-                System.out.println(score);
-            }
 
             SearchResultsWrapper.IDScore firstScore = scores.get(0);
             if (firstScore.getLongID() != k) {
@@ -223,6 +220,9 @@ public class Float16VectorExample {
                     throw new RuntimeException(String.format("The output vector is not equal to original vector: ID %d", k));
                 }
             }
+            System.out.println("\nTarget vector: " + originVector);
+            System.out.println("Top0 result: " + firstScore);
+            System.out.println("Top0 result vector: " + outputVector);
         }
         System.out.println("Search result is correct");
 
@@ -316,19 +316,13 @@ public class Float16VectorExample {
             throw new RuntimeException("The query result is incorrect");
         }
 
-        List<Float> vector = new ArrayList<>();
+        List<Float> outVector;
         if (bfloat16) {
-            TBfloat16 tf = CommonUtils.decodeTensorBF16Vector(outputBuf);
-            for (long i = 0; i < tf.size(); i++) {
-                vector.add(tf.getFloat(i));
-            }
+            outVector = CommonUtils.decodeBF16VectorToFloat(outputBuf);
         } else {
-            TFloat16 tf = CommonUtils.decodeTensorFP16Vector(outputBuf);
-            for (long i = 0; i < tf.size(); i++) {
-                vector.add(tf.getFloat(i));
-            }
+            outVector = CommonUtils.decodeFP16VectorToFloat(outputBuf);
         }
-        System.out.println(vector);
+        System.out.println("Output vector: " + outVector);
         System.out.println("Query result is correct");
 
         // drop the collection if you don't need the collection anymore

+ 3 - 3
examples/src/main/java/io/milvus/v1/GeneralExample.java

@@ -349,9 +349,9 @@ public class GeneralExample {
         for (int i = 0; i < vectors.size(); ++i) {
             System.out.println("Search result of No." + i);
             List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
-            System.out.println(scores);
-            System.out.println("Output field data for No." + i);
-            System.out.println(wrapper.getFieldData(AGE_FIELD, i));
+            for (SearchResultsWrapper.IDScore score : scores) {
+                System.out.println(score);
+            }
         }
 
         return response;

+ 1 - 0
examples/src/main/java/io/milvus/v1/SparseVectorExample.java

@@ -149,6 +149,7 @@ public class SparseVectorExample {
             Random ran = new Random();
             int k = ran.nextInt(rowCount);
             SortedMap<Long, Float> targetVector = vectors.get(k);
+            System.out.println("\nTarget vector: " + targetVector);
             R<SearchResults> searchRet = milvusClient.search(SearchParam.newBuilder()
                     .withCollectionName(COLLECTION_NAME)
                     .withMetricType(MetricType.IP)

+ 8 - 11
examples/src/main/java/io/milvus/v2/BinaryVectorExample.java

@@ -42,10 +42,10 @@ import java.util.*;
 
 public class BinaryVectorExample {
     private static final String COLLECTION_NAME = "java_sdk_example_binary_vector_v2";
-    private static final String ID_FIELD = "id";
+    private static final String ID_FIELD = "pk";
     private static final String VECTOR_FIELD = "vector";
 
-    private static final Integer VECTOR_DIM = 512;
+    private static final Integer VECTOR_DIM = 128;
 
 
     public static void main(String[] args) {
@@ -126,6 +126,8 @@ public class BinaryVectorExample {
             Random ran = new Random();
             int k = ran.nextInt(rowCount);
             ByteBuffer targetVector = vectors.get(k);
+            System.out.printf("\nANN search for vector ID=%d:\n", k);
+            CommonUtils.printBinaryVector(targetVector);
             Map<String,Object> params = new HashMap<>();
             params.put("nprobe",16);
             SearchResp searchResp = client.search(SearchReq.builder()
@@ -141,16 +143,11 @@ public class BinaryVectorExample {
             // Here we only input one vector to search, get the result of No.0 vector to check
             List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
             List<SearchResp.SearchResult> results = searchResults.get(0);
-            System.out.printf("The result of No.%d target vector:\n", i);
+            System.out.printf("The result of No.%d target vector, ID=%d:\n", i, k);
             for (SearchResp.SearchResult result : results) {
-                System.out.println(result.getEntity());
-                System.out.printf("ID: %d, Score: %f, Vector: ", result.getId(), result.getScore());
+                System.out.println(result);
                 ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
-                vector.rewind();
-                while (vector.hasRemaining()) {
-                    System.out.print(Integer.toBinaryString(vector.get()));
-                }
-                System.out.println();
+                CommonUtils.printBinaryVector(vector);
             }
 
             SearchResp.SearchResult firstResult = results.get(0);
@@ -165,7 +162,7 @@ public class BinaryVectorExample {
         int n = 99;
         QueryResp queryResp = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
-                .filter(String.format("id == %d", n))
+                .filter(String.format("%s == %d", ID_FIELD, n))
                 .outputFields(Collections.singletonList(VECTOR_FIELD))
                 .build());
 

+ 7 - 2
examples/src/main/java/io/milvus/v2/Float16VectorExample.java

@@ -162,9 +162,12 @@ public class Float16VectorExample {
             }
             Map<String, Object> entity = topResult.getEntity();
             ByteBuffer vectorBuf = (ByteBuffer) entity.get(vectorFieldName);
-            if (!vectorBuf.equals(targetVectors.get(i).getData())) {
+            ByteBuffer targetVectorBuf = (ByteBuffer)targetVectors.get(i).getData();
+            if (!vectorBuf.equals(targetVectorBuf)) {
                 throw new RuntimeException("The top1 output vector is incorrect");
             }
+            List<Float> decodedTargetVector = CommonUtils.decodeFloat16Vector(targetVectorBuf,
+                    BF16_VECTOR_FIELD.equals(vectorFieldName));
             // The method for converting float16 vector to float32 vector can be found in
             // CommonUtils.
             List<Float> decodedFpVector = CommonUtils.decodeFloat16Vector(vectorBuf,
@@ -172,7 +175,9 @@ public class Float16VectorExample {
             if (decodedFpVector.size() != VECTOR_DIM) {
                 throw new RuntimeException("The decoded vector dimension is incorrect");
             }
-            System.out.println(results.get(0));
+            System.out.println("\nTarget vector: " + decodedTargetVector);
+            System.out.println("Top0 result: " + topResult);
+            System.out.println("Top0 result vector: " + decodedFpVector);
         }
         System.out.println("Search result of " + vectorFieldName + " is correct");
     }

+ 1 - 1
examples/src/main/java/io/milvus/v2/FullTextSearchExample.java

@@ -38,7 +38,7 @@ public class FullTextSearchExample {
         List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
         for (List<SearchResp.SearchResult> results : searchResults) {
             for (SearchResp.SearchResult result : results) {
-                System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString());
+                System.out.println(result);
             }
         }
         System.out.println("=============================================================");

+ 1 - 1
examples/src/main/java/io/milvus/v2/GeneralExample.java

@@ -205,7 +205,7 @@ public class GeneralExample {
         for (List<SearchResp.SearchResult> results : searchResults) {
             System.out.println("Search result of No." + i++);
             for (SearchResp.SearchResult result : results) {
-                System.out.printf("ID: %s, Score: %f, %s\n", result.getId(), result.getScore(), result.getEntity().toString());
+                System.out.println(result);
             }
         }
     }

+ 1 - 1
examples/src/main/java/io/milvus/v2/HybridSearchExample.java

@@ -223,7 +223,7 @@ public class HybridSearchExample {
             System.out.printf("============= Search result of No.%d vector =============\n", i);
             List<SearchResp.SearchResult> results = searchResults.get(i);
             for (SearchResp.SearchResult result : results) {
-                System.out.printf("{id: %d, score: %f}%n", result.getId(), result.getScore());
+                System.out.println(result);
             }
         }
     }

+ 2 - 3
examples/src/main/java/io/milvus/v2/Int8VectorExample.java

@@ -136,10 +136,9 @@ public class Int8VectorExample {
             List<SearchResp.SearchResult> results = searchResults.get(0);
             System.out.printf("\nThe result of No.%d vector %s:\n", k, Arrays.toString(targetVector.array()));
             for (SearchResp.SearchResult result : results) {
-                System.out.printf("ID: %d, Score: %f, Vector: ", (long)result.getId(), result.getScore());
+                System.out.println(result);
                 ByteBuffer vector = (ByteBuffer) result.getEntity().get(VECTOR_FIELD);
-                System.out.print(Arrays.toString(vector.array()));
-                System.out.println();
+                System.out.println(Arrays.toString(vector.array()));
             }
 
             SearchResp.SearchResult firstResult = results.get(0);

+ 2 - 1
examples/src/main/java/io/milvus/v2/SparseVectorExample.java

@@ -119,6 +119,7 @@ public class SparseVectorExample {
             Random ran = new Random();
             int k = ran.nextInt(rowCount);
             SortedMap<Long, Float> targetVector = vectors.get(k);
+            System.out.println("\nTarget vector: " + targetVector);
             Map<String,Object> params = new HashMap<>();
             params.put("drop_ratio_search",0.2);
             SearchResp searchResp = client.search(SearchReq.builder()
@@ -136,7 +137,7 @@ public class SparseVectorExample {
             List<SearchResp.SearchResult> results = searchResults.get(0);
             System.out.printf("The result of No.%d target vector:\n", i);
             for (SearchResp.SearchResult result : results) {
-                System.out.println(result.getEntity());
+                System.out.println(result);
             }
 
             SearchResp.SearchResult firstResult = results.get(0);

+ 1 - 1
examples/src/main/java/io/milvus/v2/TextMatchExample.java

@@ -36,7 +36,7 @@ public class TextMatchExample {
         System.out.println("\nQuery with filter: " + filter);
         List<QueryResp.QueryResult> records = queryRet.getQueryResults();
         for (QueryResp.QueryResult record : records) {
-            System.out.println(record.getEntity());
+            System.out.println(record);
         }
         System.out.printf("%d items matched%n", records.size());
         System.out.println("=============================================================");

+ 12 - 12
sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java

@@ -27,6 +27,7 @@ import io.milvus.param.Constant;
 import io.milvus.response.basic.RowRecordWrapper;
 import lombok.Getter;
 import lombok.NonNull;
+import org.apache.commons.lang3.StringUtils;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -161,6 +162,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
 
         // set id and score
         IDs ids = results.getIds();
+        String pkName = results.getPrimaryFieldName();
         if (ids.hasIntId()) {
             LongArray longIDs = ids.getIntId();
             if (offset + k > longIDs.getDataCount()) {
@@ -168,7 +170,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScores.add(new IDScore("", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
+                idScores.add(new IDScore(pkName, "", longIDs.getData((int)offset + n), results.getScores((int)offset + n)));
             }
         } else if (ids.hasStrId()) {
             StringArray strIDs = ids.getStrId();
@@ -177,7 +179,7 @@ public class SearchResultsWrapper extends RowRecordWrapper {
             }
 
             for (int n = 0; n < k; ++n) {
-                idScores.add(new IDScore(strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
+                idScores.add(new IDScore(pkName, strIDs.getData((int)offset + n), 0, results.getScores((int)offset + n)));
             }
         } else {
             // in v2.3.3, return an empty list instead of throwing exception
@@ -272,12 +274,14 @@ public class SearchResultsWrapper extends RowRecordWrapper {
      */
     @Getter
     public static final class IDScore {
+        private final String primaryKey;
         private final String strID;
         private final long longID;
         private final float score;
         Map<String, Object> fieldValues = new HashMap<>();
 
-        public IDScore(String strID, long longID, float score) {
+        public IDScore(String primaryKey, String strID, long longID, float score) {
+            this.primaryKey = primaryKey;
             this.strID = strID;
             this.longID = longID;
             this.score = score;
@@ -333,16 +337,12 @@ public class SearchResultsWrapper extends RowRecordWrapper {
 
         @Override
         public String toString() {
-            List<String> pairs = new ArrayList<>();
-            fieldValues.forEach((keyName, fieldValue) -> {
-                pairs.add(keyName + ":" + fieldValue);
-            });
-
-            if (strID.isEmpty()) {
-                return "(ID: " + getLongID() + " Score: " + getScore() + " OutputFields: " + pairs + ")";
-            } else {
-                return "(ID: '" + getStrID() + "' Score: " + getScore()+ " OutputFields: " + pairs + ")";
+            Object id = strID;
+            if (StringUtils.isEmpty(strID)) {
+                id = longID;
             }
+
+            return "{" + getPrimaryKey() + ": " + id + ", Score: " + getScore() + ", OutputFields: " + fieldValues + "}";
         }
     }
 }

+ 7 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java

@@ -45,5 +45,12 @@ public class SearchResp {
         private Map<String, Object> entity = new HashMap<>();
         private Float score;
         private Object id;
+        @Builder.Default
+        private String primaryKey = "id";
+
+        @Override
+        public String toString() {
+            return "{" + getPrimaryKey() + ": " + getId() + ", Score: " + getScore() + ", OutputFields: " + entity + "}";
+        }
     }
 }

+ 1 - 0
sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java

@@ -75,6 +75,7 @@ public class ConvertUtils {
             searchResults.add(searchResultsWrapper.getIDScore(i).stream().map(idScore -> SearchResp.SearchResult.builder()
                     .entity(idScore.getFieldValues())
                     .score(idScore.getScore())
+                    .primaryKey(idScore.getPrimaryKey())
                     .id(idScore.getStrID().isEmpty() ? idScore.getLongID() : idScore.getStrID())
                     .build()).collect(Collectors.toList()));
         }