Browse Source

Refine example code (#850)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 year ago
parent
commit
87ba0076f9

+ 60 - 76
examples/main/java/io/milvus/HybridSearchExample.java

@@ -19,7 +19,7 @@
 package io.milvus;
 
 import com.alibaba.fastjson.JSONObject;
-import com.google.common.collect.Lists;
+import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.common.clientenum.ConsistencyLevelEnum;
 import io.milvus.grpc.DataType;
@@ -40,8 +40,15 @@ import java.util.List;
 
 
 public class HybridSearchExample {
-    private static final String HOST = "localhost";
-    private static final int HOST_PORT = 19530;
+    private static final MilvusClient milvusClient;
+
+    static {
+        milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
+                .withHost("localhost")
+                .withPort(19530)
+                .build());
+    }
+
     private static final String COLLECTION_NAME = "java_sdk_example_hybrid_search";
     private static final String ID_FIELD = "ID";
 
@@ -60,19 +67,16 @@ public class HybridSearchExample {
     private static final String SPARSE_VECTOR_FIELD = "sparse_vector";
     private static final MetricType SPARSE_VECTOR_METRIC = MetricType.IP;
 
-    private static void createCollection() {
-        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
-                .withHost(HOST)
-                .withPort(HOST_PORT)
+    private void createCollection() {
+        R<RpcStatus> resp = milvusClient.dropCollection(DropCollectionParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
                 .build());
+        CommonUtils.handleResponseStatus(resp);
 
         // Define fields
-        // There is a configuration in milvus.yaml to define the max vector fields in a collection
-        // proxy.maxVectorFieldNum: 4
+        // Note: There is a configuration in milvus.yaml to define the max vector fields in a collection
+        //     proxy.maxVectorFieldNum: 4
         // By default, the max vector fields number is 4
-        // In v2.4.0 there is a known bug that sparse vectors and float16 vectors in one collection
-        // will crash milvus, so we comment out float16 vector field here.
-        // https://github.com/milvus-io/milvus/issues/31988
         List<FieldType> fieldsSchema = Arrays.asList(
                 FieldType.newBuilder()
                         .withName(ID_FIELD)
@@ -90,11 +94,11 @@ public class HybridSearchExample {
                         .withDataType(DataType.BinaryVector)
                         .withDimension(BINARY_VECTOR_DIM)
                         .build(),
-//                FieldType.newBuilder()
-//                        .withName(FLOAT16_VECTOR_FIELD)
-//                        .withDataType(DataType.Float16Vector)
-//                        .withDimension(FLOAT16_VECTOR_DIM)
-//                        .build(),
+                FieldType.newBuilder()
+                        .withName(FLOAT16_VECTOR_FIELD)
+                        .withDataType(DataType.Float16Vector)
+                        .withDimension(FLOAT16_VECTOR_DIM)
+                        .build(),
                 FieldType.newBuilder()
                         .withName(SPARSE_VECTOR_FIELD)
                         .withDataType(DataType.SparseFloatVector)
@@ -102,7 +106,7 @@ public class HybridSearchExample {
         );
 
         // Create the collection with multi vector fields
-        R<RpcStatus> resp = milvusClient.createCollection(CreateCollectionParam.newBuilder()
+        resp = milvusClient.createCollection(CreateCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .withSchema(CollectionSchemaParam.newBuilder().withFieldTypes(fieldsSchema).build())
                 .build());
@@ -126,14 +130,14 @@ public class HybridSearchExample {
                 .build());
         CommonUtils.handleResponseStatus(resp);
 
-//        resp = milvusClient.createIndex(CreateIndexParam.newBuilder()
-//                .withCollectionName(COLLECTION_NAME)
-//                .withFieldName(FLOAT16_VECTOR_FIELD)
-//                .withIndexType(IndexType.IVF_FLAT)
-//                .withExtraParam("{\"nlist\":128}")
-//                .withMetricType(FLOAT16_VECTOR_METRIC)
-//                .build());
-//        CommonUtils.handleResponseStatus(resp);
+        resp = milvusClient.createIndex(CreateIndexParam.newBuilder()
+                .withCollectionName(COLLECTION_NAME)
+                .withFieldName(FLOAT16_VECTOR_FIELD)
+                .withIndexType(IndexType.HNSW)
+                .withExtraParam("{\"M\":16,\"efConstruction\":64}")
+                .withMetricType(FLOAT16_VECTOR_METRIC)
+                .build());
+        CommonUtils.handleResponseStatus(resp);
 
         resp = milvusClient.createIndex(CreateIndexParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
@@ -150,16 +154,9 @@ public class HybridSearchExample {
                 .build());
 
         System.out.println("Collection created");
-
-        milvusClient.close();
     }
 
-    private static void insertData() {
-        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
-                .withHost(HOST)
-                .withPort(HOST_PORT)
-                .build());
-
+    private void insertData() {
         long idCount = 0;
         int rowCount = 10000;
         // Insert entities by rows
@@ -169,7 +166,7 @@ public class HybridSearchExample {
             row.put(ID_FIELD, idCount++);
             row.put(FLOAT_VECTOR_FIELD, CommonUtils.generateFloatVector(FLOAT_VECTOR_DIM));
             row.put(BINARY_VECTOR_FIELD, CommonUtils.generateBinaryVector(BINARY_VECTOR_DIM));
-//            row.put(FLOAT16_VECTOR_FIELD, CommonUtils.generateFloat16Vector(FLOAT16_VECTOR_DIM, false));
+            row.put(FLOAT16_VECTOR_FIELD, CommonUtils.generateFloat16Vector(FLOAT16_VECTOR_DIM, false));
             row.put(SPARSE_VECTOR_FIELD, CommonUtils.generateSparseVector());
             rows.add(row);
         }
@@ -194,8 +191,8 @@ public class HybridSearchExample {
                 CommonUtils.generateFloatVectors(FLOAT_VECTOR_DIM, rowCount)));
         fieldsInsert.add(new InsertParam.Field(BINARY_VECTOR_FIELD,
                 CommonUtils.generateBinaryVectors(BINARY_VECTOR_DIM, rowCount)));
-//        fieldsInsert.add(new InsertParam.Field(FLOAT16_VECTOR_FIELD,
-//                CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, rowCount, false)));
+        fieldsInsert.add(new InsertParam.Field(FLOAT16_VECTOR_FIELD,
+                CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, rowCount, false)));
         fieldsInsert.add(new InsertParam.Field(SPARSE_VECTOR_FIELD,
                 CommonUtils.generateSparseVectors(rowCount)));
 
@@ -206,16 +203,9 @@ public class HybridSearchExample {
         CommonUtils.handleResponseStatus(resp);
 
         System.out.printf("%d entities inserted by columns\n", rowCount);
-
-        milvusClient.close();
     }
 
-    private static void hybridSearch() {
-        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
-                .withHost(HOST)
-                .withPort(HOST_PORT)
-                .build());
-
+    private void hybridSearch() {
         // Get the row count
         R<GetCollectionStatisticsResponse> resp = milvusClient.getCollectionStatistics(GetCollectionStatisticsParam
                 .newBuilder()
@@ -228,10 +218,10 @@ public class HybridSearchExample {
         System.out.println("Collection row count: " + stat.getRowCount());
 
         // Search on multiple vector fields
-        // Note that only allow one vector for each sub request
+        int NQ = 2;
         AnnSearchParam req1 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(FLOAT_VECTOR_FIELD)
-                .withFloatVectors(CommonUtils.generateFloatVectors(FLOAT_VECTOR_DIM, 1))
+                .withFloatVectors(CommonUtils.generateFloatVectors(FLOAT_VECTOR_DIM, NQ))
                 .withMetricType(FLOAT_VECTOR_METRIC)
                 .withParams("{\"nprobe\": 32}")
                 .withTopK(10)
@@ -239,22 +229,22 @@ public class HybridSearchExample {
 
         AnnSearchParam req2 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(BINARY_VECTOR_FIELD)
-                .withBinaryVectors(CommonUtils.generateBinaryVectors(BINARY_VECTOR_DIM, 1))
+                .withBinaryVectors(CommonUtils.generateBinaryVectors(BINARY_VECTOR_DIM, NQ))
                 .withMetricType(BINARY_VECTOR_METRIC)
                 .withTopK(15)
                 .build();
 
-//        AnnSearchParam req3 = AnnSearchParam.newBuilder()
-//                .withVectorFieldName(FLOAT16_VECTOR_FIELD)
-//                .withFloat16Vectors(CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, 1, false))
-//                .withMetricType(FLOAT16_VECTOR_METRIC)
-//                .withParams("{\"es\":200}")
-//                .withTopK(20)
-//                .build();
+        AnnSearchParam req3 = AnnSearchParam.newBuilder()
+                .withVectorFieldName(FLOAT16_VECTOR_FIELD)
+                .withFloat16Vectors(CommonUtils.generateFloat16Vectors(FLOAT16_VECTOR_DIM, NQ, false))
+                .withMetricType(FLOAT16_VECTOR_METRIC)
+                .withParams("{\"ef\":64}")
+                .withTopK(20)
+                .build();
 
         AnnSearchParam req4 = AnnSearchParam.newBuilder()
                 .withVectorFieldName(SPARSE_VECTOR_FIELD)
-                .withSparseFloatVectors(CommonUtils.generateSparseVectors(1))
+                .withSparseFloatVectors(CommonUtils.generateSparseVectors(NQ))
                 .withMetricType(SPARSE_VECTOR_METRIC)
                 .withParams("{\"drop_ratio_search\":0.2}")
                 .withTopK(20)
@@ -264,11 +254,11 @@ public class HybridSearchExample {
                 .withCollectionName(COLLECTION_NAME)
                 .addOutField(FLOAT_VECTOR_FIELD)
                 .addOutField(BINARY_VECTOR_FIELD)
-//                .addOutField(FLOAT16_VECTOR_FIELD)
+                .addOutField(FLOAT16_VECTOR_FIELD)
                 .addOutField(SPARSE_VECTOR_FIELD)
                 .addSearchRequest(req1)
                 .addSearchRequest(req2)
-//                .addSearchRequest(req3)
+                .addSearchRequest(req3)
                 .addSearchRequest(req4)
                 .withTopK(5)
                 .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
@@ -282,35 +272,29 @@ public class HybridSearchExample {
 
         // Print search result
         SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults());
-        List<SearchResultsWrapper.IDScore> scores = results.getIDScore(0);
-        for (int i = 0; i < scores.size(); ++i) {
-            System.out.println(scores.get(i));
+        for (int k = 0; k < NQ; k++) {
+            System.out.printf("============= Search result of No.%d vector =============\n", k);
+            List<SearchResultsWrapper.IDScore> scores = results.getIDScore(k); // hits of the k-th query vector, not always the first
+            for (int i = 0; i < scores.size(); ++i) {
+                System.out.println(scores.get(i));
+            }
         }
-
-
-        milvusClient.close();
     }
 
-    private static void dropCollection() {
-        MilvusServiceClient milvusClient = new MilvusServiceClient(ConnectParam.newBuilder()
-                .withHost(HOST)
-                .withPort(HOST_PORT)
-                .build());
-
+    private void dropCollection() {
         R<RpcStatus> resp = milvusClient.dropCollection(DropCollectionParam.newBuilder()
                 .withCollectionName(COLLECTION_NAME)
                 .build());
         CommonUtils.handleResponseStatus(resp);
 
         System.out.println("Collection dropped");
-
-        milvusClient.close();
     }
 
     public static void main(String[] args) {
-        createCollection();
-        insertData();
-        hybridSearch();
-        dropCollection();
+        HybridSearchExample example = new HybridSearchExample();
+        try {
+            example.createCollection();
+            example.insertData();
+            example.hybridSearch();
+            example.dropCollection();
+        } finally {
+            // The client is now a shared static instance and the per-method close() calls
+            // were removed, so release the connection once here, even if a step failed.
+            milvusClient.close();
+        }
     }
 }

+ 7 - 7
src/main/java/io/milvus/param/dml/AnnSearchParam.java

@@ -76,8 +76,8 @@ public class AnnSearchParam {
         private String params = "{}";
 
         // plType is used to distinct vector type
-        // for Float16Vector/BFloat16Vector and BinaryVector, user input ByteBuffer
-        // the server cannot distinct a ByteBuffer is a BinarVector or a Float16Vector
+        // for Float16Vector/BFloat16Vector and BinaryVector, user inputs ByteBuffer
+        // the sdk cannot distinguish whether a ByteBuffer is a BinaryVector or a Float16Vector
         private PlaceholderType plType = PlaceholderType.None;
 
         Builder() {
@@ -137,7 +137,7 @@ public class AnnSearchParam {
         public Builder withFloatVectors(@NonNull List<List<Float>> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.FloatVector;
+            this.plType = PlaceholderType.FloatVector;
             return this;
         }
 
@@ -150,7 +150,7 @@ public class AnnSearchParam {
         public Builder withBinaryVectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.BinaryVector;
+            this.plType = PlaceholderType.BinaryVector;
             return this;
         }
 
@@ -163,7 +163,7 @@ public class AnnSearchParam {
         public Builder withFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.Float16Vector;
+            this.plType = PlaceholderType.Float16Vector;
             return this;
         }
 
@@ -176,7 +176,7 @@ public class AnnSearchParam {
         public Builder withBFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.BFloat16Vector;
+            this.plType = PlaceholderType.BFloat16Vector;
             return this;
         }
 
@@ -189,7 +189,7 @@ public class AnnSearchParam {
         public Builder withSparseFloatVectors(@NonNull List<SortedMap<Long, Float>> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.SparseFloatVector;
+            this.plType = PlaceholderType.SparseFloatVector;
             return this;
         }
 

+ 10 - 10
src/main/java/io/milvus/param/dml/SearchParam.java

@@ -111,8 +111,8 @@ public class SearchParam {
         private String groupByFieldName;
 
         // plType is used to distinct vector type
-        // for Float16Vector/BFloat16Vector and BinaryVector, user input ByteBuffer
-        // the server cannot distinct a ByteBuffer is a BinarVector or a Float16Vector
+        // for Float16Vector/BFloat16Vector and BinaryVector, user inputs ByteBuffer
+        // the sdk cannot distinguish whether a ByteBuffer is a BinaryVector or a Float16Vector
         private PlaceholderType plType = PlaceholderType.None;
 
         Builder() {
@@ -246,14 +246,14 @@ public class SearchParam {
 
         /**
          * Sets the target vectors.
-         * Note: Deprecated in v2.4.0, for the reason that the server cannot know a ByteBuffer
+         * Note: Deprecated in v2.4.0, because the sdk cannot tell whether a ByteBuffer
          *       is a BinarVector or Float16Vector/BFloat16Vector.
-         *       Replaced by withFloatVectors/withBinaryVectors/withFloat16Vectors/withBFloat16Vectors.
+         *       Replaced by withFloatVectors/withBinaryVectors/withFloat16Vectors/withBFloat16Vectors/withSparseFloatVectors.
          *       It still works for FloatVector/BinarVector/SparseVector, don't use it for Float16Vector/BFloat16Vector.
          *
          * @param vectors list of target vectors:
          *                if vector type is FloatVector, vectors is List of List Float;
-         *                if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
+         *                if vector type is BinaryVector, vectors is List of ByteBuffer;
          *                if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
          * @return <code>Builder</code>
          */
@@ -273,7 +273,7 @@ public class SearchParam {
         public Builder withFloatVectors(@NonNull List<List<Float>> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.FloatVector;
+            this.plType = PlaceholderType.FloatVector;
             return this;
         }
 
@@ -286,7 +286,7 @@ public class SearchParam {
         public Builder withBinaryVectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.BinaryVector;
+            this.plType = PlaceholderType.BinaryVector;
             return this;
         }
 
@@ -299,7 +299,7 @@ public class SearchParam {
         public Builder withFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.Float16Vector;
+            this.plType = PlaceholderType.Float16Vector;
             return this;
         }
 
@@ -312,7 +312,7 @@ public class SearchParam {
         public Builder withBFloat16Vectors(@NonNull List<ByteBuffer> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.BFloat16Vector;
+            this.plType = PlaceholderType.BFloat16Vector;
             return this;
         }
 
@@ -325,7 +325,7 @@ public class SearchParam {
         public Builder withSparseFloatVectors(@NonNull List<SortedMap<Long, Float>> vectors) {
             this.vectors = vectors;
             this.NQ = (long) vectors.size();
-            plType = PlaceholderType.SparseFloatVector;
+            this.plType = PlaceholderType.SparseFloatVector;
             return this;
         }
 

+ 1 - 3
src/main/java/io/milvus/param/highlevel/dml/SearchSimpleParam.java

@@ -120,9 +120,7 @@ public class SearchSimpleParam {
         /**
          * Sets the target vectors.
          *
-         * @param vectors list of target vectors:
-         *               if vector type is FloatVector, vectors is List of List Float;
-         *               if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
+         * @param vectors list of target vectors: List of List Float;
          * @return <code>Builder</code>
          */
         public Builder withVectors(@NonNull List<?> vectors) {