Просмотр исходного кода

Support NGRAM and BoostRanker (#1640)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 3 месяцев назад
Родитель
Сommit
362cc93f7a

+ 11 - 11
examples/src/main/java/io/milvus/v2/HybridSearchExample.java

@@ -73,7 +73,7 @@ public class HybridSearchExample {
     private static final String SPARSE_VECTOR_FIELD = "sparse_vector";
     private static final IndexParam.MetricType SPARSE_VECTOR_METRIC = IndexParam.MetricType.IP;
 
-    private void createCollection() {
+    private static void createCollection() {
         client.dropCollection(DropCollectionReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .build());
@@ -147,7 +147,7 @@ public class HybridSearchExample {
         System.out.println("Collection created");
     }
 
-    private void insertData() {
+    private static void insertData() {
         long idCount = 0;
         int rowCount = 10000;
         // Insert entities by rows
@@ -167,11 +167,10 @@ public class HybridSearchExample {
                 .collectionName(COLLECTION_NAME)
                 .data(rows)
                 .build());
-
-        System.out.printf("%d entities inserted by rows\n", rowCount);
+        printRowCount();
     }
 
-    private void hybridSearch() {
+    private static void printRowCount() {
         // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
         QueryResp countR = client.query(QueryReq.builder()
                 .collectionName(COLLECTION_NAME)
@@ -179,7 +178,9 @@ public class HybridSearchExample {
                 .consistencyLevel(ConsistencyLevel.STRONG)
                 .build());
         System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
+    }
 
+    private static void hybridSearch() {
         // Search on multiple vector fields
         int NQ = 2;
         List<BaseVector> floatVectors = new ArrayList<>();
@@ -229,7 +230,7 @@ public class HybridSearchExample {
         }
     }
 
-    private void dropCollection() {
+    private static void dropCollection() {
         client.dropCollection(DropCollectionReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .build());
@@ -237,11 +238,10 @@ public class HybridSearchExample {
     }
 
     public static void main(String[] args) {
-        io.milvus.v2.HybridSearchExample example = new io.milvus.v2.HybridSearchExample();
-        example.createCollection();
-        example.insertData();
-        example.hybridSearch();
-        example.dropCollection();
+        createCollection();
+        insertData();
+        hybridSearch();
+        dropCollection();
 
         client.close();
     }

+ 28 - 1
examples/src/main/java/io/milvus/v2/JsonFieldExample.java

@@ -98,6 +98,32 @@ public class JsonFieldExample {
                 .metricType(IndexParam.MetricType.COSINE)
                 .build());
 
+        // Create INVERTED index for a specific entry of JSON field
+        // Index for JSON field is supported from milvus v2.5.7 and fully supported in v2.5.13+
+        // Read the doc for more info: https://milvus.io/docs/json-indexing.md
+        Map<String,Object> p1 = new HashMap<>();
+        p1.put("json_path", "metadata[\"flags\"]");
+        p1.put("json_cast_type", "array_double");
+        indexes.add(IndexParam.builder()
+                .fieldName(JSON_FIELD)
+                .indexType(IndexParam.IndexType.INVERTED)
+                .extraParams(p1)
+                .build());
+
+        // Create NGRAM index for a specific entry of JSON field
+        // NGRAM index for JSON field is supported from milvus v2.6.2
+        // Read the doc for more info: https://milvus.io/docs/ngram.md
+        Map<String,Object> p2 = new HashMap<>();
+        p2.put("json_path","metadata[\"path\"]");
+        p2.put("json_cast_type", "varchar");
+        p2.put("min_gram", 3);
+        p2.put("max_gram", 5);
+        indexes.add(IndexParam.builder()
+                .fieldName(JSON_FIELD)
+                .indexType(IndexParam.IndexType.NGRAM)
+                .extraParams(p2)
+                .build());
+
         CreateCollectionReq requestCreate = CreateCollectionReq.builder()
                 .collectionName(COLLECTION_NAME)
                 .collectionSchema(collectionSchema)
@@ -121,7 +147,7 @@ public class JsonFieldExample {
             // Note: for JSON field, always construct a real JsonObject
             // don't use row.addProperty(JSON_FIELD, strContent) since the value is treated as a string, not a JsonObject
             JsonObject metadata = new JsonObject();
-            metadata.addProperty("path", String.format("\\root/abc/path_%d", i));
+            metadata.addProperty("path", String.format("\\root/abc_%d/path_%d", i, i));
             metadata.addProperty("size", i);
             if (i%7 == 0) {
                 metadata.addProperty("special", true);
@@ -197,6 +223,7 @@ public class JsonFieldExample {
         queryWithExpr(client, "JSON_CONTAINS(metadata[\"flags\"], 9)");
         queryWithExpr(client, "JSON_CONTAINS_ANY(metadata[\"flags\"], [8, 9, 10])");
         queryWithExpr(client, "JSON_CONTAINS_ALL(metadata[\"flags\"], [8, 9, 10])");
+        queryWithExpr(client, "metadata[\"path\"] LIKE \"%c_5%\"");
         queryWithExpr(client, "dynamic1 < 2.0");
 
         client.close();

+ 267 - 0
examples/src/main/java/io/milvus/v2/RankerExample.java

@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2;
+
+import com.google.gson.JsonObject;
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.vector.request.FunctionScore;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchReq;
+import io.milvus.v2.service.vector.request.data.EmbeddedText;
+import io.milvus.v2.service.vector.request.ranker.BoostRanker;
+import io.milvus.v2.service.vector.request.ranker.DecayRanker;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
+
+import java.util.*;
+
+public class RankerExample {
+    private static final MilvusClientV2 client;
+
+    static {
+        ConnectConfig config = ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build();
+        client = new MilvusClientV2(config);
+    }
+
+    private static final String COLLECTION_NAME = "java_sdk_example_ranker_v2";
+    private static final String NAME_FIELD = "name";
+    private static final String BIRTH_YEAR_FIELD = "birth_year";
+    private static final String LIFESPAN_FIELD = "lifespan";
+    private static final String SPARSE_VECTOR_FIELD = "sparse_vector";
+
+    private static class Person {
+        public String name;
+        public int fromYear;
+        public int toYear;
+        public Person(String name, int from, int to) {
+            this.name = name;
+            this.fromYear = from;
+            this.toYear = to;
+        }
+    }
+
+    private static List<Person> genData() {
+        List<Person> persons = new ArrayList<>();
+        persons.add(new Person("Isaac Newton", 1643, 1727));
+        persons.add(new Person("Albert Einstein", 1879, 1955));
+        persons.add(new Person("Marie Curie", 1867, 1934));
+        persons.add(new Person("Charles Darwin", 1809, 1882));
+        persons.add(new Person("Galileo Galilei", 1564, 1642));
+        persons.add(new Person("Nikola Tesla", 1856, 1943));
+        persons.add(new Person("James Clerk Maxwell", 1831, 1879));
+        persons.add(new Person("Thomas Edison", 1847, 1931));
+        persons.add(new Person("Alexander Fleming", 1881, 1955));
+        persons.add(new Person("Louis Pasteur", 1822, 1895));
+        persons.add(new Person("Werner Heisenberg", 1901, 1976));
+        persons.add(new Person("Stephen Hawking", 1942, 2018));
+        persons.add(new Person("Dmitri Mendeleev", 1834, 1907));
+        persons.add(new Person("Max Planck", 1858, 1947));
+        persons.add(new Person("Niels Bohr", 1885, 1962));
+        persons.add(new Person("Richard Feynman", 1918, 1988));
+        persons.add(new Person("Carl Sagan", 1934, 1996));
+        persons.add(new Person("Francis Crick", 1916, 2004));
+        persons.add(new Person("Rosalind Franklin", 1920, 1958));
+        persons.add(new Person("Edwin Hubble", 1889, 1953));
+        persons.add(new Person("Linus Pauling", 1901, 1994));
+        persons.add(new Person("Alan Turing", 1912, 1954));
+        persons.add(new Person("Guglielmo Marconi", 1874, 1937));
+        persons.add(new Person("Michael Faraday", 1791, 1867));
+        persons.add(new Person("Enrico Fermi", 1901, 1954));
+        persons.add(new Person("Johannes Kepler", 1571, 1630));
+        persons.add(new Person("Edwin Schrödinger", 1887, 1961));
+        persons.add(new Person("Werner von Braun", 1912, 1977));
+        persons.add(new Person("Albert Hofmann", 1906, 2008));
+        persons.add(new Person("Robert Oppenheimer", 1904, 1967));
+        persons.add(new Person("Edwin Land", 1909, 1991));
+        persons.add(new Person("Rachel Carson", 1907, 1964));
+        persons.add(new Person("Ernest Rutherford", 1871, 1937));
+        persons.add(new Person("Hans Geiger", 1882, 1945));
+        persons.add(new Person("John Bardeen", 1908, 1991));
+        persons.add(new Person("George Washington Carver", 1864, 1943));
+        return persons;
+    }
+
+    private static void createCollection() {
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+
+        // Create collection
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(NAME_FIELD)
+                .dataType(DataType.VarChar)
+                .isPrimaryKey(Boolean.TRUE)
+                .maxLength(1024)
+                .enableAnalyzer(true)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(BIRTH_YEAR_FIELD)
+                .dataType(DataType.Int64)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(LIFESPAN_FIELD)
+                .dataType(DataType.Int8)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(SPARSE_VECTOR_FIELD)
+                .dataType(DataType.SparseFloatVector)
+                .build());
+
+        collectionSchema.addFunction(CreateCollectionReq.Function.builder()
+                .functionType(FunctionType.BM25)
+                .name("function_bm25")
+                .inputFieldNames(Collections.singletonList(NAME_FIELD))
+                .outputFieldNames(Collections.singletonList(SPARSE_VECTOR_FIELD))
+                .build());
+
+        List<IndexParam> indexes = new ArrayList<>();
+        indexes.add(IndexParam.builder()
+                .fieldName(SPARSE_VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
+                .metricType(IndexParam.MetricType.BM25)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .indexParams(indexes)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .build();
+        client.createCollection(requestCreate);
+        System.out.println("Collection created");
+    }
+
+    private static void insertData() {
+        List<JsonObject> rows = new ArrayList<>();
+        List<Person> data = genData();
+        for (Person person : data) {
+            JsonObject row = new JsonObject();
+            row.addProperty(NAME_FIELD, person.name);
+            row.addProperty(BIRTH_YEAR_FIELD, person.fromYear);
+            row.addProperty(LIFESPAN_FIELD, person.toYear - person.fromYear);
+            rows.add(row);
+        }
+
+        client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(rows)
+                .build());
+        printRowCount();
+    }
+
+    private static void printRowCount() {
+        // Get row count, set ConsistencyLevel.STRONG to sync the data to query node so that data is visible
+        QueryResp countR = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .outputFields(Collections.singletonList("count(*)"))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
+    }
+
+    private static void dropCollection() {
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection dropped");
+    }
+
+    private static void searchWithRanker(String text, CreateCollectionReq.Function rankerFunction) {
+        System.out.println("\n=============================================================");
+        SearchReq.SearchReqBuilder builder = SearchReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(Collections.singletonList(new EmbeddedText(text)))
+                .limit(100)
+                .outputFields(Arrays.asList(BIRTH_YEAR_FIELD, LIFESPAN_FIELD));
+
+        if (rankerFunction != null) {
+            builder.functionScore(FunctionScore.builder()
+                    .addFunction(rankerFunction)
+                    .build());
+            System.out.printf("Search text '%s' with ranker '%s'\n\n", text, rankerFunction.getName());
+        } else {
+            System.out.printf("Search text '%s' without ranker\n\n", text);
+        }
+
+        // The text is tokenized inside server and turned into a sparse embedding to compare with the vector field
+        SearchResp searchResp = client.search(builder.build());
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        for (List<SearchResp.SearchResult> results : searchResults) {
+            for (SearchResp.SearchResult result : results) {
+                System.out.println(result);
+            }
+        }
+    }
+
+    private static void searchWithoutRanker(String text) {
+        searchWithRanker(text, null);
+    }
+
+    public static void main(String[] args) {
+        createCollection();
+        insertData();
+
+        // Search scientists with name or surname
+        String scientists = "Albert, Charles, Darwin and Edwin";
+        searchWithoutRanker(scientists);
+
+        // Search scientists with name or surname
+        // Rerank the results by linear decay, the scores are rearranged according to the birth years
+        // Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
+        // The scientist whose birth year is close to 1900 will get a high score
+        DecayRanker decay = DecayRanker.builder()
+                .name("birth_year_linear_decay")
+                .inputFieldNames(Collections.singletonList(BIRTH_YEAR_FIELD))
+                .function("linear")
+                .origin(1900)
+                .scale(50)
+                .offset(0)
+                .decay(0.1)
+                .build();
+        searchWithRanker(scientists, decay);
+
+        // Search scientists with name or surname
+        // Rerank the results by boost, the scores are rearranged according to the birth years
+        // Read the doc for more info: https://milvus.io/docs/boost-ranker.md
+        // The scientist whose lifespan is between 60 and 70 will get a high score
+        BoostRanker boost = BoostRanker.builder()
+                .name("boost_on_lifespan")
+                .filter(String.format("%s > 60 and %s < 70", LIFESPAN_FIELD, LIFESPAN_FIELD))
+                .weight(5.0f)
+                .build();
+        searchWithRanker(scientists, boost);
+
+        dropCollection();
+        client.close();
+    }
+}

+ 5 - 2
sdk-core/src/main/java/io/milvus/v2/common/IndexParam.java

@@ -89,12 +89,15 @@ public class IndexParam {
         // Only for varchar type field
         TRIE("Trie", 100),
 
+        // Only for varchar type field and json_path of JSON field
+        NGRAM(101),
+
         // Only for geometry type field
-        RTREE("RTREE", 120),
+        RTREE(120),
 
         // Only for scalar type field
         STL_SORT(200), // only for numeric type field
-        INVERTED(201), // works for all scalar fields except JSON type field
+        INVERTED(201), // works for all scalar fields and json_path of JSON field
         BITMAP(202), // works for all scalar fields except JSON, FLOAT and DOUBLE type fields
 
         // Only for sparse vectors

+ 89 - 0
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/BoostRanker.java

@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2.service.vector.request.ranker;
+
+import io.milvus.common.clientenum.FunctionType;
+import io.milvus.common.utils.JsonUtils;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import lombok.experimental.SuperBuilder;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The Decay reranking strategy, which by adjusting search rankings based on numeric field values.
+ * Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
+ *
+ * Example:
+ * BoostRanker boost = BoostRanker.builder()
+ *                  .name("xxx_boost")
+ *                  .description("boost on xxx")
+ *                  .filter("xxx == 2")
+ *                  .weight(0.5)
+ *                  .randomScoreSeed(123)
+ *                  .randomScoreField("id")
+ *                  .build()
+ *
+ * You also can declare a decay ranker by Function
+ * CreateCollectionReq.Function boost = CreateCollectionReq.Function.builder()
+ *                  .functionType(FunctionType.RERANK)
+ *                  .name("xxx_boost")
+ *                  .description("boost on xxx")
+ *                  .param("reranker", "boost")
+ *                  .param("filter", "xxx == 2")
+ *                  .param("weight", "0.5")
+ *                  .param("random_score", "{\"seed\": 123, \"field\": \"id\"}")
+ *                  .build();
+ */
+@SuperBuilder
+public class BoostRanker extends CreateCollectionReq.Function {
+    private String filter;
+    private Float weight;
+    private Long randomScoreSeed;
+    private String randomScoreField;
+
+    public FunctionType getFunctionType() {
+        return FunctionType.RERANK;
+    }
+
+    public Map<String, String> getParams() {
+        Map<String, String> props = super.getParams();
+        props.put("reranker", "boost");
+        if (!StringUtils.isEmpty(filter)) {
+            props.put("filter", filter);
+        }
+        if (weight != null) {
+            props.put("weight", weight.toString());
+        }
+
+        Map<String, Object> randomScore = new HashMap<>();
+        if (randomScoreSeed != null) {
+            randomScore.put("seed", randomScoreSeed);
+        }
+        if (!StringUtils.isEmpty(randomScoreField)) {
+            randomScore.put("field", randomScoreField);
+        }
+        if (!randomScore.isEmpty()) {
+            props.put("random_score", JsonUtils.toJson(randomScore));
+        }
+        return props;
+    }
+}

+ 24 - 4
sdk-core/src/main/java/io/milvus/v2/service/vector/request/ranker/DecayRanker.java

@@ -30,16 +30,28 @@ import java.util.Map;
  * The Decay reranking strategy, which by adjusting search rankings based on numeric field values.
  * Read the doc for more info: https://milvus.io/docs/decay-ranker-overview.md
  *
- * You also can declare a decay ranker by Function
- * CreateCollectionReq.Function rr = CreateCollectionReq.Function.builder()
+ * Example:
+ * DecayRanker decay = DecayRanker.builder()
+ *                  .name("time_decay")
+ *                  .description("time decay")
+ *                  .inputFieldNames(Collections.singletonList("timestamp"))
+ *                  .function("gauss")
+ *                  .origin(100)
+ *                  .scale(50)
+ *                  .offset(24)
+ *                  .decay(0.5)
+ *                  .build()
+ *
+ * You also can declare a decay ranker by Function:
+ * CreateCollectionReq.Function decay = CreateCollectionReq.Function.builder()
  *                 .functionType(FunctionType.RERANK)
  *                 .name("time_decay")
  *                 .description("time decay")
  *                 .inputFieldNames(Collections.singletonList("timestamp"))
  *                 .param("reranker", "decay")
  *                 .param("function", "gauss")
- *                 .param("origin", "1000")
- *                 .param("scale", "10000")
+ *                 .param("origin", "100")
+ *                 .param("scale", "50")
  *                 .param("offset", "24")
  *                 .param("decay", "0.5")
  *                 .build();
@@ -49,7 +61,9 @@ public class DecayRanker extends CreateCollectionReq.Function {
     @Builder.Default
     private String function = "gauss";
     private Number origin;
+    private Number offset;
     private Number scale;
+    private Number decay;
 
     public FunctionType getFunctionType() {
         return FunctionType.RERANK;
@@ -63,9 +77,15 @@ public class DecayRanker extends CreateCollectionReq.Function {
         if (origin != null) {
             props.put("origin", origin.toString());
         }
+        if (offset != null) {
+            props.put("offset", offset.toString());
+        }
         if (scale != null) {
             props.put("scale", scale.toString());
         }
+        if (decay != null) {
+            props.put("decay", decay.toString());
+        }
         return props;
     }
 }