Просмотр исходного кода

Add group-by example (#1662)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 месяцев назад
Родитель
Сommit
0a4e20595c
1 измененных файлов с 238 добавлено и 0 удалено
  1. 238 0
      examples/src/main/java/io/milvus/v2/GroupByExample.java

+ 238 - 0
examples/src/main/java/io/milvus/v2/GroupByExample.java

@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq.Function;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.vector.request.*;
+import io.milvus.v2.service.vector.request.data.BaseVector;
+import io.milvus.v2.service.vector.request.data.EmbeddedText;
+import io.milvus.v2.service.vector.request.data.FloatVec;
+import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class GroupByExample {
+    private static final String COLLECTION_NAME = "java_sdk_example_group_by_v2";
+    private static final String ID_FIELD = "id";
+    private static final String DENSE_FIELD = "dense";
+    private static final String SPARSE_FIELD = "sparse";
+    private static final String TEXT_FIELD = "text";
+    private static final String DOCID_FIELD = "docId";
+    private static final int DIM = 5;
+
+    private static void searchGroupBy(MilvusClientV2 client, String groupField, long limit, int groupSize, boolean strictGroup) {
+        System.out.println("\n====================================================================================");
+        BaseVector targetVector = new FloatVec(new float[]{0.145292f, 0.914725f, 0.796505f, 0.700925f, 0.560520f});
+        SearchResp searchResp = client.search(SearchReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(Collections.singletonList(targetVector))
+                .annsField(DENSE_FIELD)
+                .limit(limit)
+                .outputFields(Arrays.asList(DOCID_FIELD, TEXT_FIELD))
+                .groupByFieldName(groupField)
+                .groupSize(groupSize)
+                .strictGroupSize(strictGroup)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+        System.out.printf("Search with group by field: %s, group size: %d, strict: %b, limit: %d%n",
+                groupField, groupSize, strictGroup, limit);
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        for (List<SearchResp.SearchResult> results : searchResults) {
+            for (SearchResp.SearchResult result : results) {
+                System.out.println(result);
+            }
+        }
+    }
+
+    private static void hybridSearchGroupBy(MilvusClientV2 client, String groupField, long limit, int groupSize, boolean strictGroup) {
+        System.out.println("\n====================================================================================");
+        BaseVector targetVector = new FloatVec(new float[]{0.145292f, 0.914725f, 0.796505f, 0.700925f, 0.560520f});
+        BaseVector targetText = new EmbeddedText("Milvus vector database");
+        List<AnnSearchReq> searchRequests = new ArrayList<>();
+        searchRequests.add(AnnSearchReq.builder()
+                .vectorFieldName(DENSE_FIELD)
+                .vectors(Collections.singletonList(targetVector))
+                .limit(limit * 2)
+                .build());
+        searchRequests.add(AnnSearchReq.builder()
+                .vectorFieldName(SPARSE_FIELD)
+                .vectors(Collections.singletonList(targetText))
+                .limit(limit * 2)
+                .build());
+
+        SearchResp searchResp = client.hybridSearch(HybridSearchReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .searchRequests(searchRequests)
+                .functionScore(FunctionScore.builder()
+                        .addFunction(WeightedRanker.builder().weights(Arrays.asList(0.5f, 0.5f)).build())
+                        .build())
+                .limit(limit)
+                .outFields(Arrays.asList(DOCID_FIELD, TEXT_FIELD))
+                .groupByFieldName(groupField)
+                .groupSize(groupSize)
+                .strictGroupSize(strictGroup)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build());
+        System.out.printf("HybridSearch with group by field: %s, group size: %d, strict: %b, limit: %d%n",
+                groupField, groupSize, strictGroup, limit);
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        for (List<SearchResp.SearchResult> results : searchResults) {
+            for (SearchResp.SearchResult result : results) {
+                System.out.println(result);
+            }
+        }
+    }
+
+    public static void main(String[] args) {
+        ConnectConfig config = ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build();
+        MilvusClientV2 client = new MilvusClientV2(config);
+
+        // Drop collection if exists
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+
+        // Create collection
+        CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .autoID(false)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName(DENSE_FIELD)
+                .dataType(DataType.FloatVector)
+                .dimension(DIM)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName(SPARSE_FIELD)
+                .dataType(DataType.SparseFloatVector)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName(TEXT_FIELD)
+                .dataType(DataType.VarChar)
+                .maxLength(65535)
+                .enableAnalyzer(true) // must enable this if you use Function
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName(DOCID_FIELD)
+                .dataType(DataType.Int32)
+                .build());
+
+        // With this function, milvus will convert the strings of "text" field to sparse vectors of "vector" field
+        // by built-in tokenizer and analyzer
+        // Read the link for more info: https://milvus.io/docs/full-text-search.md
+        schema.addFunction(Function.builder()
+                .functionType(FunctionType.BM25)
+                .name("function_bm25")
+                .inputFieldNames(Collections.singletonList(TEXT_FIELD))
+                .outputFieldNames(Collections.singletonList(SPARSE_FIELD))
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        indexes.add(IndexParam.builder()
+                .fieldName(DENSE_FIELD)
+                .indexType(IndexParam.IndexType.FLAT)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+        indexes.add(IndexParam.builder()
+                .fieldName(SPARSE_FIELD)
+                .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
+                .metricType(IndexParam.MetricType.BM25) // to use full text search, metric type must be "BM25"
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(schema)
+                .indexParams(indexes)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("Collection created");
+
+        // Insert rows
+        Gson gson = new Gson();
+        List<JsonObject> rows = Arrays.asList(
+                gson.fromJson("{\"id\": 0, \"text\": \"Milvus is an open-source vector database\", \"dense\": [0.358037, -0.602349, 0.184140, -0.262862, 0.902943], \"docId\": 1}", JsonObject.class),
+                gson.fromJson("{\"id\": 1, \"text\": \"AI applications help people better life\", \"dense\": [0.198868, 0.060235, 0.697696, 0.261447, 0.838729], \"docId\": 5}", JsonObject.class),
+                gson.fromJson("{\"id\": 2, \"text\": \"Will the electric car replace gas-powered car?\", \"dense\": [0.437421, -0.559750, 0.645788, 0.789405, 0.207857], \"docId\": 2}", JsonObject.class),
+                gson.fromJson("{\"id\": 3, \"text\": \"LangChain is a composable framework to build with LLMs. Milvus is integrated into LangChain.\", \"dense\": [0.317200, 0.971904, -0.369811, 0.120690, -0.144627], \"docId\": 4}", JsonObject.class),
+                gson.fromJson("{\"id\": 4, \"text\": \"RAG is the process of optimizing the output of a large language model\", \"dense\": [0.837197, -0.015764, -0.310629, -0.562666, -0.898494], \"docId\": 1}", JsonObject.class),
+                gson.fromJson("{\"id\": 5, \"text\": \"Newton is one of the greatest scientist of human history\", \"dense\": [-0.33445, -0.256713, 0.898753, 0.940299, 0.537806], \"docId\": 2}", JsonObject.class),
+                gson.fromJson("{\"id\": 6, \"text\": \"Metric type L2 is Euclidean distance\", \"dense\": [0.395247, 0.400025, -0.589050, -0.865050, -0.6140360], \"docId\": 5}", JsonObject.class),
+                gson.fromJson("{\"id\": 7, \"text\": \"Embeddings represent real-world objects, like words, images, or videos, in a form that computers can process.\", \"dense\": [0.571828, 0.240703, -0.373791, -0.067269, -0.6980531], \"docId\": 3}", JsonObject.class)
+        );
+
+        client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rows)
+                .build());
+
+        // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
+        QueryResp countR = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .outputFields(Collections.singletonList("count(*)"))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        System.out.printf("%d rows in collection\n", (long) countR.getQueryResults().get(0).getEntity().get("count(*)"));
+
+        // search without group by, limit 3
+        searchGroupBy(client, null, 3, 0, false);
+        // search group by docId, limit 3, group size 1, strict false
+        searchGroupBy(client, DOCID_FIELD, 3, 1, false);
+        // search group by docId, limit 3, group size 3, strict false
+        searchGroupBy(client, DOCID_FIELD, 3, 2, false);
+        // search group by docId, limit 3, group size 3, strict true
+        searchGroupBy(client, DOCID_FIELD, 3, 2, true);
+        // search group by docId, limit 4, group size 3, strict false
+        searchGroupBy(client, DOCID_FIELD, 4, 3, false);
+        // search group by docId, limit 4, group size 3, strict false
+        searchGroupBy(client, DOCID_FIELD, 4, 3, true);
+
+        // hybrid search without group by, limit 3
+        hybridSearchGroupBy(client, null, 3, 0, false);
+        // hybrid search group by docId, limit 3, group size 2, strict false
+        hybridSearchGroupBy(client, DOCID_FIELD, 3, 2, false);
+        // hybrid search group by docId, limit 3, group size 2, strict true
+        hybridSearchGroupBy(client, DOCID_FIELD, 3, 2, true);
+
+        client.close();
+    }
+}