Browse Source

Merge pull request #147 from sahuang/0.9.2

Use schema for examples
Xiaohai Xu 4 years ago
parent
commit
4082af875b

+ 1 - 9
examples/src/main/java/MilvusBasicExample.java

@@ -19,15 +19,7 @@
 
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
-import io.milvus.client.CollectionMapping;
-import io.milvus.client.CompactParam;
-import io.milvus.client.ConnectParam;
-import io.milvus.client.DataType;
-import io.milvus.client.InsertParam;
-import io.milvus.client.MilvusClient;
-import io.milvus.client.MilvusGrpcClient;
-import io.milvus.client.SearchParam;
-import io.milvus.client.SearchResult;
+import io.milvus.client.*;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;

+ 36 - 44
examples/src/main/java/MilvusIndexExample.java

@@ -17,18 +17,10 @@
  * under the License.
  */
 
-import io.milvus.client.CollectionMapping;
-import io.milvus.client.ConnectParam;
-import io.milvus.client.DataType;
-import io.milvus.client.Index;
-import io.milvus.client.IndexType;
-import io.milvus.client.InsertParam;
-import io.milvus.client.JsonBuilder;
-import io.milvus.client.MetricType;
-import io.milvus.client.MilvusClient;
-import io.milvus.client.MilvusGrpcClient;
-import io.milvus.client.SearchParam;
-import io.milvus.client.SearchResult;
+import io.milvus.client.*;
+import io.milvus.client.dsl.MilvusService;
+import io.milvus.client.dsl.Query;
+import io.milvus.client.dsl.Schema;
 import java.io.BufferedReader;
 import java.io.FileReader;
 import java.io.IOException;
@@ -54,6 +46,8 @@ import org.json.JSONObject;
  */
 public class MilvusIndexExample {
 
+  public static final int dimension = 8;
+
   // Helper function that generates random float vectors
   private static List<List<Float>> randomFloatVectors(int vectorCount, int dimension) {
     SplittableRandom splitCollectionRandom = new SplittableRandom();
@@ -87,15 +81,15 @@ public class MilvusIndexExample {
       client.dropCollection(collectionName);
     }
 
-    // Create collection
-    final int dimension = 8;
-    CollectionMapping collectionMapping =
-        CollectionMapping.create(collectionName)
-            .addField("release_year", DataType.INT64)
-            .addVectorField("embedding", DataType.VECTOR_FLOAT, dimension)
-            .setParamsInJson("{\"segment_row_limit\": 4096, \"auto_id\": false}");
-
-    client.createCollection(collectionMapping);
+    /*
+     * Basic create collection:
+     *   Another way to create a collection is to predefine a schema.
+     *   The schema can later be used in querying.
+     */
+    FilmSchema filmSchema = new FilmSchema();
+    MilvusService service = new MilvusService(client, collectionName, filmSchema);
+    service.createCollection(
+        new JsonBuilder().param("auto_id", false).param("segment_row_limit", 4096).build());
 
     /*
      * Basic insert and create index:
@@ -110,7 +104,7 @@ public class MilvusIndexExample {
     BufferedReader csvReader = new BufferedReader(new FileReader(path));
     List<Long> ids = new ArrayList<>();
     List<String> titles = new ArrayList<>();
-    List<Long> releaseYears = new ArrayList<>();
+    List<Integer> releaseYears = new ArrayList<>();
     List<List<Float>> embeddings = new ArrayList<>();
     String row;
     while ((row = csvReader.readLine()) != null) {
@@ -118,7 +112,7 @@ public class MilvusIndexExample {
       // process four columns in order
       ids.add(Long.parseLong(data[0]));
       titles.add(data[1]);
-      releaseYears.add(Long.parseLong(data[2]));
+      releaseYears.add(Integer.parseInt(data[2]));
       List<Float> embedding = new ArrayList<>(dimension);
       for (int i = 3; i < 11; i++) {
         // 8 float values in a vector
@@ -137,7 +131,7 @@ public class MilvusIndexExample {
     // Now we can insert entities, the total row count should be 8657.
     InsertParam insertParam =
         InsertParam.create(collectionName)
-            .addField("release_year", DataType.INT64, releaseYears)
+            .addField("release_year", DataType.INT32, releaseYears)
             .addVectorField("embedding", DataType.VECTOR_FLOAT, embeddings)
             .setEntityIds(ids);
 
@@ -176,29 +170,21 @@ public class MilvusIndexExample {
      *
      *   Based on the index you created, the available search parameters will be different. Refer to
      *   Milvus documentation for how to set the optimal parameters based on your needs.
+     *
+     *   Here we present a way to use predefined schema to search vectors.
      */
     List<List<Float>> queryEmbedding = randomFloatVectors(1, dimension);
-    final long topK = 3;
-    String dsl =
-        String.format(
-            "{\"bool\": {"
-                + "\"must\": [{"
-                + "    \"term\": {"
-                + "        \"release_year\": [2002, 1995]"
-                + "    }},{"
-                + "    \"vector\": {"
-                + "        \"embedding\": {"
-                + "            \"topk\": %d, \"metric_type\": \"L2\", \"type\": \"float\", \"query\": "
-                + "%s, \"params\": {\"nprobe\": 8}"
-                + "    }}}]}}",
-            topK, queryEmbedding.toString());
-
-    SearchParam searchParam =
-        SearchParam.create(collectionName)
-            .setDsl(dsl)
-            .setParamsInJson("{\"fields\": [\"release_year\", \"embedding\"]}");
+    final int topK = 3;
+    Query query = Query.bool(Query.must(
+        filmSchema.releaseYear.in(1995, 2002),
+        filmSchema.embedding.query(queryEmbedding)
+            .metricType(MetricType.L2)
+            .top(topK)
+            .param("nprobe", 8)));
+    SearchParam searchParam = service.buildSearchParam(query)
+        .setParamsInJson("{\"fields\": [\"release_year\", \"embedding\"]}");
     System.out.println("\n--------Search Result--------");
-    SearchResult searchResult = client.search(searchParam);
+    SearchResult searchResult = service.search(searchParam);
     System.out.println("- ids: " + searchResult.getResultIdsList().toString());
     System.out.println("- distances: " + searchResult.getResultDistancesList().toString());
     for (List<Map<String, Object>> singleQueryResult : searchResult.getFieldsMap()) {
@@ -226,4 +212,10 @@ public class MilvusIndexExample {
     // Close connection
     client.close();
   }
+
+  // Schema that defines a collection
+  private static class FilmSchema extends Schema {
+    public final Int32Field releaseYear = new Int32Field("release_year");
+    public final FloatVectorField embedding = new FloatVectorField("embedding", dimension);
+  }
 }